
Sarkars/upgrade r19 (#250)
sayantan-nervana authored Sep 19, 2019
1 parent 218990b commit 3f45c87
Showing 10 changed files with 118 additions and 46 deletions.
4 changes: 4 additions & 0 deletions ngraph_bridge/ngraph_builder.cc
@@ -29,6 +29,7 @@
#include "ngraph/op/util/logical_reduction.hpp"

#include "logging/ngraph_log.h"
#include "ngraph_bridge/ngraph_api.h"
#include "ngraph_bridge/ngraph_backend_manager.h"
#include "ngraph_bridge/ngraph_builder.h"
#include "ngraph_bridge/ngraph_conversions.h"
@@ -97,6 +98,9 @@ std::shared_ptr<TOpType> ConstructNgNode(const std::string& op_name,
auto ng_node = std::make_shared<TOpType>(std::forward<TArg>(Args)...);
ng_node->set_friendly_name(op_name);
ng_node->add_provenance_tag(op_name);
if (config::IsLoggingPlacement()) {
cout << "TF_to_NG: " << op_name << " --> " << ng_node->get_name() << "\n";
}
return ng_node;
}

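The logging added to ConstructNgNode prints the TF-to-nGraph op mapping whenever placement logging is enabled. Below is a minimal, self-contained sketch of the same variadic forwarding pattern, assuming stand-in types (Node, Add, and ConstructNode are hypothetical and exist only for illustration; they are not the bridge's classes):

#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Stand-in for an nGraph op type; hypothetical, for illustration only.
struct Node {
  virtual ~Node() = default;
  std::string friendly_name;
  std::string name() const { return "Node_0"; }
};

struct Add : Node {
  Add(int a, int b) : sum(a + b) {}
  int sum;
};

// Mirrors the shape of ConstructNgNode: forward the constructor arguments to
// the concrete op type, tag the node with the originating TF op name, and
// optionally log the mapping.
template <class TOpType, class... TArg>
std::shared_ptr<TOpType> ConstructNode(const std::string& op_name,
                                       TArg&&... args) {
  auto ng_node = std::make_shared<TOpType>(std::forward<TArg>(args)...);
  ng_node->friendly_name = op_name;  // set_friendly_name/add_provenance_tag in the bridge
  std::cout << "TF_to_NG: " << op_name << " --> " << ng_node->name() << "\n";
  return ng_node;
}

int main() {
  auto n = ConstructNode<Add>("my_tf_add", 2, 3);
  std::cout << n->friendly_name << " sum = " << n->sum << "\n";
}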
14 changes: 9 additions & 5 deletions ngraph_bridge/ngraph_encapsulate_clusters.cc
@@ -763,10 +763,6 @@ Status EncapsulateClusters(
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, *gdef_for_current_encapsulate,
&graph_for_current_encapsulate));
TF_RETURN_IF_ERROR(Builder::TranslateGraph(
input_shapes, static_input_map, &graph_for_current_encapsulate,
ng_function));
string serialized_ngfunc(ngraph::serialize(ng_function, 4));

// get backend.
// TODO: Note that this is code duplication of some stuff present
@@ -793,6 +789,14 @@ Status EncapsulateClusters(
}
TF_RETURN_IF_ERROR(BackendManager::CreateBackend(
op_backend_name)); // Created a backend here. must free it
// TranslateGraph must be called AFTER CreateBackend because some TF
// ops like CNMS and gather use backend specific nodes
TF_RETURN_IF_ERROR(Builder::TranslateGraph(
input_shapes, static_input_map, &graph_for_current_encapsulate,
ng_function));
int json_indentation = 4;
string serialized_ngfunc(
ngraph::serialize(ng_function, json_indentation));
std::unordered_map<std::string, std::string> additional_attribute_map;
for (auto itr : node->attrs()) {
// Find the optional attributes to be sent to the backend.
@@ -804,7 +808,7 @@
// For e.g. _ngraph_ice_cores --> ice_cores
if (itr.first.find("_ngraph_") != std::string::npos) {
// leave out _ngraph_aot_requested
if (itr.first.find("_ngraph_aot_requested") !=
if (itr.first.find("_ngraph_aot_requested") ==
std::string::npos) {
additional_attribute_map.insert(
{itr.first.substr(strlen("_ngraph_")), itr.second.s()});
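A note on the "!=" to "==" fix above: find(...) == std::string::npos means the attribute name does not contain "_ngraph_aot_requested", so the corrected condition forwards every "_ngraph_"-prefixed attribute except "_ngraph_aot_requested" to the backend, with the prefix stripped. A self-contained sketch of that filtering rule, assuming a plain std::map as a stand-in for the node's attribute list:

#include <cstring>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

int main() {
  // Hypothetical attribute list; only the naming convention matters here.
  std::map<std::string, std::string> attrs = {
      {"_ngraph_ice_cores", "4"},
      {"_ngraph_aot_requested", "1"},
      {"T", "DT_FLOAT"}};
  std::unordered_map<std::string, std::string> additional_attribute_map;
  for (const auto& itr : attrs) {
    // Keep only the optional "_ngraph_" attributes destined for the backend.
    if (itr.first.find("_ngraph_") != std::string::npos) {
      // "== npos": the name does NOT contain "_ngraph_aot_requested",
      // which is exactly the corrected condition in the diff.
      if (itr.first.find("_ngraph_aot_requested") == std::string::npos) {
        additional_attribute_map.insert(
            {itr.first.substr(std::strlen("_ngraph_")), itr.second});
      }
    }
  }
  for (const auto& kv : additional_attribute_map)
    std::cout << kv.first << " = " << kv.second << "\n";  // prints: ice_cores = 4
}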
33 changes: 19 additions & 14 deletions ngraph_bridge/ngraph_encapsulate_impl.cc
@@ -134,32 +134,32 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
MemoryProfile(vm0, rss0);

NGRAPH_VLOG(1) << "Compilation cache miss: " << m_name;
string serialized_ng_func;
if (!m_do_aot) {
TF_RETURN_IF_ERROR(Builder::TranslateGraph(input_shapes, static_input_map,
&m_graph, ng_function));
ng_function->set_friendly_name(m_name);
int json_indentation = 4;
serialized_ng_func = ngraph::serialize(ng_function, json_indentation);
} else {
auto itr = m_aot_functions.find(signature);
if (itr == m_aot_functions.end()) {
return errors::Internal(
"Expected to find AOT precompiled ng function of signature: ",
signature);
}
ng_function = ng::deserialize(itr->second);
serialized_ng_func = itr->second;
}

auto function_size = ng_function->get_graph_size() / 1024; // kb unit

// Serialize to nGraph if needed
if (std::getenv("NGRAPH_ENABLE_SERIALIZE") != nullptr) {
std::string file_name = "tf_function_" + m_name + ".json";
NgraphSerialize("tf_function_" + m_name + ".json", ng_function);
StringToFile("tf_function_" + m_name + ".json", serialized_ng_func);
#if defined NGRAPH_DISTRIBUTED
int rank_id;
rank_id = ng::get_distributed_interface()->get_rank();
NgraphSerialize(
"tf_function_" + m_name + "_" + to_string(rank_id) + ".json",
ng_function);
StringToFile("tf_function_" + m_name + "_" + to_string(rank_id) + ".json",
serialized_ng_func);
#endif
}
// Evict the cache if the number of elements exceeds the limit
@@ -172,7 +172,7 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
int input_tensors_bytes_free = 0;
evicted_ng_exec = m_ng_exec_map[m_lru.back()];
m_ng_exec_map.erase(m_lru.back());
m_ng_function_map.erase(evicted_ng_exec);
m_serialized_ng_function_map.erase(evicted_ng_exec);

// Call delete function here for the erased func
op_backend->remove_compiled_function(evicted_ng_exec);
@@ -222,12 +222,12 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
}
} catch (const std::exception& exp) {
BackendManager::UnlockBackend(m_op_backend_name);
NgraphSerialize("tf_function_error_" + m_name + ".json", ng_function);
StringToFile("tf_function_error_" + m_name + ".json", serialized_ng_func);
return errors::Internal("Caught exception while compiling op_backend: ",
exp.what(), "\n");
} catch (...) {
BackendManager::UnlockBackend(m_op_backend_name);
NgraphSerialize("tf_function_error_" + m_name + ".json", ng_function);
StringToFile("tf_function_error_" + m_name + ".json", serialized_ng_func);
return errors::Internal("Error in compiling op_backend\n");
}
BackendManager::UnlockBackend(m_op_backend_name);
@@ -236,7 +236,7 @@ Status NGraphEncapsulateImpl::GetNgExecutable(

SetNgExecMap(signature, ng_exec);
// caching ng_function to serialize to ngraph if needed
SetNgFunctionMap(ng_exec, ng_function);
m_serialized_ng_function_map[ng_exec] = serialized_ng_func;

m_lru.push_front(signature);
// Memory after
@@ -245,9 +245,8 @@ Status NGraphEncapsulateImpl::GetNgExecutable(
auto delta_res_mem = rss - rss0;
NGRAPH_VLOG(1) << "NGRAPH_TF_CACHE_PROFILE: OP_ID: " << my_instance_id
<< " Cache length: " << m_ng_exec_map.size()
<< " Cluster: " << m_name << " Delta VM: " << delta_vm_mem
<< " Delta RSS: " << delta_res_mem
<< " Function size: " << function_size
<< " Cluster: " << m_name << " Delta VM: " << delta_vm_mem
<< " Delta RSS: " << delta_res_mem
<< " KB Total RSS: " << rss / (1024 * 1024) << " GB "
<< " VM: " << vm / (1024 * 1024) << " GB" << endl;
} // end of input signature not found in m_ng_exec_map
@@ -582,6 +581,12 @@ NGraphEncapsulateImpl::GetTensorsFromPipeline(
return out_tpl;
}

void NGraphEncapsulateImpl::DumpNgFunction(
const string& file_name,
std::shared_ptr<ngraph::runtime::Executable> ng_exec) {
StringToFile(file_name, m_serialized_ng_function_map[ng_exec]);
}

} // namespace ngraph_bridge

} // namespace tensorflow
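The pattern this file now follows: serialize the nGraph function once when it is compiled, cache the JSON string keyed by the executable, and dump the cached string to a file if a later call on that executable fails. A condensed sketch with stand-in types (Executable and EncapsulateSketch are hypothetical placeholders, not the bridge's classes):

#include <fstream>
#include <memory>
#include <string>
#include <unordered_map>

struct Executable {};  // stand-in for ngraph::runtime::Executable

class EncapsulateSketch {
 public:
  // Called once per compilation cache miss: remember the serialized function.
  void Remember(const std::shared_ptr<Executable>& exec,
                const std::string& serialized_json) {
    m_serialized_ng_function_map[exec] = serialized_json;
  }
  // Called on an execution error: write the cached JSON to a file.
  void DumpNgFunction(const std::string& file_name,
                      const std::shared_ptr<Executable>& exec) {
    std::ofstream f(file_name);
    f << m_serialized_ng_function_map[exec];  // StringToFile does this in the bridge
  }

 private:
  std::unordered_map<std::shared_ptr<Executable>, std::string>
      m_serialized_ng_function_map;
};

int main() {
  EncapsulateSketch impl;
  auto exec = std::make_shared<Executable>();
  impl.Remember(exec, "{\"name\": \"tf_function_cluster_0\"}");
  impl.DumpNgFunction("tf_function_error_cluster_0.json", exec);
  return 0;
}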
25 changes: 9 additions & 16 deletions ngraph_bridge/ngraph_encapsulate_impl.h
@@ -85,6 +85,9 @@ class NGraphEncapsulateImpl {
const ng::element::Type& ng_element_type, const ng::Shape& ng_shape,
std::shared_ptr<ng::runtime::Tensor> tensor_from_pipeline);

void DumpNgFunction(const string&,
std::shared_ptr<ngraph::runtime::Executable>);

// Accessors(getters and setters) for the private data members of
// NgraphEncapsulateImpl class
// needed by
@@ -148,19 +151,6 @@

void ClearNgExecMap() { m_ng_exec_map.clear(); }

std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>,
std::shared_ptr<ngraph::Function>>
GetNgFunctionMap() {
return m_ng_function_map;
}

void SetNgFunctionMap(
const std::shared_ptr<ngraph::runtime::Executable>& exec,
const std::shared_ptr<ngraph::Function>& function) {
m_ng_function_map[exec] = function;
}

void ClearNgFunctionMap() { m_ng_function_map.clear(); }
// TODO:sindhu have another get function for output_cache which is only
// readable
std::vector<std::pair<void*, shared_ptr<ng::runtime::Tensor>>>&
@@ -179,6 +169,10 @@

void ClearNgExecOutputCache() { m_ng_exec_output_cache_map.clear(); }

void ClearNgExecSerializedFunctionCache() {
m_serialized_ng_function_map.clear();
}

NGraphFreshnessTracker* GetNgraphFreshnessTracker() {
return m_freshness_tracker;
}
@@ -236,9 +230,8 @@ class NGraphEncapsulateImpl {
// ng_function, ng_executable, Output and Input Cache maps
std::unordered_map<std::string, std::shared_ptr<ngraph::runtime::Executable>>
m_ng_exec_map;
std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>,
std::shared_ptr<ngraph::Function>>
m_ng_function_map;
std::unordered_map<std::shared_ptr<ngraph::runtime::Executable>, std::string>
m_serialized_ng_function_map;

NgFunctionIOCache m_ng_exec_input_cache_map;
NgFunctionIOCache m_ng_exec_output_cache_map;
13 changes: 5 additions & 8 deletions ngraph_bridge/ngraph_encapsulate_op.cc
@@ -262,8 +262,8 @@ NGraphEncapsulateOp::~NGraphEncapsulateOp() {
ng_encap_impl.ClearNgExecInputCache();
ng_encap_impl.ClearNgExecOutputCache();
ng_encap_impl.ClearNgExecMap();
ng_encap_impl.ClearNgFunctionMap();
ng_encap_impl.ClearNgExecPipelinedTensorMap();
ng_encap_impl.ClearNgExecSerializedFunctionCache();

// Release the backend
NGRAPH_VLOG(2) << "~NGraphEncapsulateOp():: ReleaseBackend";
@@ -291,7 +291,6 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {

std::vector<TensorShape> input_shapes;
std::vector<const Tensor*> static_input_map;
std::shared_ptr<ngraph::Function> ng_function;
std::shared_ptr<ngraph::runtime::Executable> ng_exec;
ng::runtime::Backend* op_backend;

@@ -526,19 +525,17 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) {
try {
ng_exec->call(ng_outputs, ng_inputs);
} catch (const std::exception& exp) {
ng_function = ng_encap_impl.GetNgFunctionMap()[ng_exec];
BackendManager::UnlockBackend(ng_encap_impl.GetOpBackend());
NgraphSerialize("tf_function_error_" + ctx->op_kernel().name() + ".json",
ng_function);
ng_encap_impl.DumpNgFunction(
"tf_function_error_" + ctx->op_kernel().name() + ".json", ng_exec);
OP_REQUIRES(ctx, false,
errors::Internal(
"Caught exception while executing nGraph computation: ",
exp.what(), "\n"));
} catch (...) {
ng_function = ng_encap_impl.GetNgFunctionMap()[ng_exec];
BackendManager::UnlockBackend(ng_encap_impl.GetOpBackend());
NgraphSerialize("tf_function_error_" + ctx->op_kernel().name() + ".json",
ng_function);
ng_encap_impl.DumpNgFunction(
"tf_function_error_" + ctx->op_kernel().name() + ".json", ng_exec);
OP_REQUIRES(
ctx, false,
errors::Internal("Error in executing the nGraph computation\n"));
8 changes: 6 additions & 2 deletions ngraph_bridge/ngraph_utils.cc
@@ -312,12 +312,16 @@ Status CheckAxisDimInRange(std::vector<int64> axes, size_t rank) {
void NgraphSerialize(const std::string& file_name,
const std::shared_ptr<ngraph::Function>& ng_function) {
NGRAPH_VLOG(0) << "Serializing graph to: " << file_name << std::endl;
std::string js = ngraph::serialize(ng_function, 4);
int json_indentation = 4;
StringToFile(file_name, ngraph::serialize(ng_function, json_indentation));
}

void StringToFile(const std::string& file_name, const std::string& contents) {
std::ofstream f;
f.exceptions(std::ofstream::failbit | std::ofstream::badbit);
try {
f.open(file_name);
f << js;
f << contents;
f.close();
} catch (std::ofstream::failure& e) {
NGRAPH_VLOG(0) << "Exception opening/closing file " << file_name
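StringToFile, added above, is small enough to show end to end. A standalone version under the same error-handling assumptions (the main() driver and the file name are illustrative only):

#include <fstream>
#include <iostream>
#include <string>

// Illustrative reimplementation of the helper: write `contents` to
// `file_name`, turning stream failures into exceptions that get logged.
void StringToFileSketch(const std::string& file_name,
                        const std::string& contents) {
  std::ofstream f;
  f.exceptions(std::ofstream::failbit | std::ofstream::badbit);
  try {
    f.open(file_name);
    f << contents;
    f.close();
  } catch (const std::ofstream::failure& e) {
    std::cerr << "Exception opening/closing file " << file_name << ": "
              << e.what() << "\n";
  }
}

int main() {
  StringToFileSketch("tf_function_example.json", "{\"ops\": []}");
  return 0;
}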
2 changes: 2 additions & 0 deletions ngraph_bridge/ngraph_utils.h
@@ -300,6 +300,8 @@ Status CheckAxisDimInRange(std::vector<int64> axes, size_t rank);
void NgraphSerialize(const std::string&,
const std::shared_ptr<ngraph::Function>&);

void StringToFile(const std::string&, const std::string&);

// Collect the total memory usage through /proc/self/stat
void MemoryProfile(long&, long&);

2 changes: 1 addition & 1 deletion test/model_level_tests/models/MLP/getting_repo_ready.sh
@@ -1 +1 @@
pip install -U keras
pip install -U keras==2.2.5
1 change: 1 addition & 0 deletions test/model_level_tests/models/MLP/repo.txt
@@ -1 +1,2 @@
https://github.com/keras-team/keras.git
2.2.5
62 changes: 62 additions & 0 deletions test/python/test_ngraph_serialize_flag.py
@@ -0,0 +1,62 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytest for a simple run on model testing framework
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pytest
import platform
import os

import tensorflow as tf
import numpy as np
import re

from common import NgraphTest
import ngraph_bridge


class TestNgraphSerialize(NgraphTest):

def test_ng_serialize_to_json(self):
initial_contents = set(os.listdir())
xshape = (3, 4, 5)
x = tf.placeholder(tf.float32, shape=xshape)
out = tf.nn.l2_loss(tf.abs(x))
values = np.random.rand(*xshape)

config = ngraph_bridge.update_config(tf.ConfigProto())
ngraph_enable_serialize = os.environ.pop('NGRAPH_ENABLE_SERIALIZE',
None)
os.environ['NGRAPH_ENABLE_SERIALIZE'] = '1'
ngraph_bridge.enable()
with tf.Session(config=config) as sess:
out = sess.run((out), feed_dict={x: values})
os.environ.pop('NGRAPH_ENABLE_SERIALIZE', None)
if ngraph_enable_serialize is not None:
os.environ['NGRAPH_ENABLE_SERIALIZE'] = \
ngraph_enable_serialize

final_contents = set(os.listdir())
assert (len(final_contents) - len(initial_contents) == 1)
new_files = final_contents.difference(initial_contents)
flname = new_files.pop()
assert (flname.startswith('tf_function_') and flname.endswith('json'))
os.remove(flname)
