Skip to content

Commit

Permalink
Sarkars/aot (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
sayantan-nervana authored Aug 22, 2019
1 parent c0760d5 commit d7a0b4a
Show file tree
Hide file tree
Showing 24 changed files with 1,751 additions and 130 deletions.
2 changes: 2 additions & 0 deletions bazel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cc_library(
"ngraph_bridge/ngraph_encapsulate_op.h",
"ngraph_bridge/ngraph_freshness_tracker.h",
"ngraph_bridge/ngraph_mark_for_clustering.h",
"ngraph_bridge/ngraph_partial_shapes.h",
"ngraph_bridge/ngraph_pipelined_tensors.h",
"ngraph_bridge/ngraph_rewrite_for_tracking.h",
"ngraph_bridge/ngraph_timer.h",
Expand Down Expand Up @@ -70,6 +71,7 @@ cc_binary(
"ngraph_bridge/ngraph_encapsulate_op.cc",
"ngraph_bridge/ngraph_freshness_tracker.cc",
"ngraph_bridge/ngraph_mark_for_clustering.cc",
"ngraph_bridge/ngraph_partial_shapes.cc",
"ngraph_bridge/ngraph_pipelined_tensors.cc",
"ngraph_bridge/ngraph_rewrite_for_tracking.cc",
"ngraph_bridge/ngraph_tracked_variable.cc",
Expand Down
1 change: 1 addition & 0 deletions ngraph_bridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ set(SRC
ngraph_encapsulate_op.cc
ngraph_freshness_tracker.cc
ngraph_mark_for_clustering.cc
ngraph_partial_shapes.cc
ngraph_rewrite_for_tracking.cc
ngraph_rewrite_pass.cc
ngraph_tracked_variable.cc
Expand Down
2 changes: 1 addition & 1 deletion ngraph_bridge/enable_variable_ops/ngraph_rewrite_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class NGraphEncapsulationPass : public NGraphRewritePass {
// 4. Encapsulate clusters then, if requested, dump the graphs.
FunctionDefLibrary* fdeflib_new = new FunctionDefLibrary();
TF_RETURN_IF_ERROR(EncapsulateClusters(options.graph->get(), idx,
fdeflib_new, config_map));
fdeflib_new, config_map, {0, {}}));
// TODO: not using fdeflib_new in this path. Only grappler path uses it
free(fdeflib_new);
if (DumpEncapsulatedGraphs()) {
Expand Down
40 changes: 35 additions & 5 deletions ngraph_bridge/grappler/ngraph_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,42 @@ Status NgraphOptimizer::Init(
config_backend_name = params.at("ngraph_backend").s();
config_device_id = params.at("device_id").s();
NGRAPH_VLOG(3) << "Backend name from config: " << config_backend_name;
std::set<ShapeHintMap> shape_hints;
// typedef std::map<std::string, std::vector<int>> ShapeHintMap;
for (auto i : params) {
if (i.first != "ngraph_backend") {
config_map[(i.first == "device_id" ? "" : "_") + std::string("ngraph_") +
i.first] = i.second.s();
NGRAPH_VLOG(3) << "Attribute: " << i.first
<< " Value: " << config_map["_ngraph_" + i.first];
// TODO: slightly hacky. The bridge reserves the right to use optional
// attributes whose names start with shape_hint
if (i.first.rfind("shape_hint", 0) != 0) {
config_map[(i.first == "device_id" ? "" : "_") +
std::string("ngraph_") + i.first] = i.second.s();
NGRAPH_VLOG(3) << "Attribute: " << i.first
<< " Value: " << config_map["_ngraph_" + i.first];
} else {
ShapeHintMap hint;
for (auto k : i.second.func().attr().at("hint_body").func().attr()) {
vector<int> full_or_partial_shape;
for (auto dim : k.second.tensor().int_val()) {
full_or_partial_shape.push_back(dim);
}
hint[k.first] = full_or_partial_shape;
}
shape_hints.insert(hint);
}
}
}
auto itr = params.find("aot_requested");
bool do_aot = false;
if (itr != params.end()) {
do_aot = itr->second.s() == "1";
}
if (!do_aot && shape_hints.size() > 0) {
return errors::Internal(
"Did not requested AOT, but passed shape hints. Please request to use "
"shape hints (by using --precompile in tf2ngraph.py), or if AOT is not "
"desired then do not pass shape hints");
}
aot_info = make_pair(do_aot, shape_hints);
return Status::OK();
}

Expand Down Expand Up @@ -230,7 +258,9 @@ Status NgraphOptimizer::Optimize(tensorflow::grappler::Cluster* cluster,

// 4. Encapsulate clusters then, if requested, dump the graphs.
FunctionDefLibrary* fdeflib_new = new FunctionDefLibrary();
TF_RETURN_IF_ERROR(EncapsulateClusters(&graph, idx, fdeflib_new, config_map));
TF_RETURN_IF_ERROR(
// TODO: right now _ngraph_aot_requested is passed along in config_map.
EncapsulateClusters(&graph, idx, fdeflib_new, config_map, aot_info));
if (DumpEncapsulatedGraphs()) {
DumpGraphs(graph, idx, "encapsulated", "Graph with Clusters Encapsulated");
}
Expand Down
1 change: 1 addition & 0 deletions ngraph_bridge/grappler/ngraph_optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class NgraphOptimizer : public tensorflow::grappler::CustomGraphOptimizer {

static int s_serial_counter GUARDED_BY(s_serial_counter_mutex);
static mutex s_serial_counter_mutex;
AOTInfo aot_info;
};

int NgraphOptimizer::s_serial_counter = 0;
Expand Down
Loading

0 comments on commit d7a0b4a

Please sign in to comment.