diff --git a/CMakeLists.txt b/CMakeLists.txt index 5615d3825..3c69ded92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,7 +244,7 @@ if (NOT USE_PRE_BUILT_NGRAPH) ExternalProject_Add( ext_ngraph GIT_REPOSITORY https://github.com/NervanaSystems/ngraph - GIT_TAG v0.25.0-rc.3 + GIT_TAG v0.25.1-rc.0 CMAKE_ARGS -DNGRAPH_DISTRIBUTED_ENABLE=${NGRAPH_DISTRIBUTED_ENABLE} -DNGRAPH_INSTALL_PREFIX=${NGRAPH_ARTIFACTS_DIR} diff --git a/README.md b/README.md index afd1949df..1f961154b 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ Once TensorFlow's dependencies are installed, clone the `ngraph-bridge` repo: git clone https://github.com/tensorflow/ngraph-bridge.git cd ngraph-bridge - git checkout v0.18.0-rc4 + git checkout v0.19.0-rc0 Run the following Python script to build TensorFlow, nGraph, and the bridge. Use Python 3.5: diff --git a/bazel/BUILD b/bazel/BUILD index e131484cd..90a0762ec 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -33,6 +33,7 @@ cc_library( "ngraph_bridge/ngraph_encapsulate_op.h", "ngraph_bridge/ngraph_freshness_tracker.h", "ngraph_bridge/ngraph_mark_for_clustering.h", + "ngraph_bridge/ngraph_pipelined_tensors.h", "ngraph_bridge/ngraph_rewrite_for_tracking.h", "ngraph_bridge/ngraph_timer.h", "ngraph_bridge/ngraph_utils.h", @@ -69,6 +70,7 @@ cc_binary( "ngraph_bridge/ngraph_encapsulate_op.cc", "ngraph_bridge/ngraph_freshness_tracker.cc", "ngraph_bridge/ngraph_mark_for_clustering.cc", + "ngraph_bridge/ngraph_pipelined_tensors.cc", "ngraph_bridge/ngraph_rewrite_for_tracking.cc", "ngraph_bridge/ngraph_tracked_variable.cc", "ngraph_bridge/ngraph_utils.cc", diff --git a/bazel/WORKSPACE b/bazel/WORKSPACE index a242b3e27..aa234e466 100644 --- a/bazel/WORKSPACE +++ b/bazel/WORKSPACE @@ -25,11 +25,11 @@ tf_configure( http_archive( name = "ngraph", build_file = "//:bazel/ngraph.BUILD", - sha256 = "0b0cbd617653552d219c05bf975acfbcac513061a7b04465a71db324a9d9d7e3", - strip_prefix = "ngraph-0.25.0-rc.3", + sha256 = "030a0c22a098a958e1856b7930bdd7d1694ec882b61de3108afa8e59ff960e42", + strip_prefix = "ngraph-0.25.1-rc.0", urls = [ - "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.25.0-rc.3.tar.gz", - "https://github.com/NervanaSystems/ngraph/archive/v0.25.0-rc.3.tar.gz" + "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.0.tar.gz", + "https://github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.0.tar.gz" ], ) diff --git a/build_ngtf.py b/build_ngtf.py index d42a70bc6..3f468a8df 100755 --- a/build_ngtf.py +++ b/build_ngtf.py @@ -53,7 +53,7 @@ def main(): ''' # Component versions - ngraph_version = "v0.25.0-rc.3" + ngraph_version = "v0.25.1-rc.0" tf_version = "v1.14.0" # Command line parser options diff --git a/ngraph_bridge/CMakeLists.txt b/ngraph_bridge/CMakeLists.txt index 8ef4d3f53..0ad8d4161 100644 --- a/ngraph_bridge/CMakeLists.txt +++ b/ngraph_bridge/CMakeLists.txt @@ -42,6 +42,7 @@ set(SRC ngraph_conversions.cc ngraph_deassign_clusters.cc ngraph_encapsulate_clusters.cc + ngraph_pipelined_tensors.cc ngraph_encapsulate_impl.cc ops/ngraph_ops.cc ngraph_encapsulate_op.cc diff --git a/ngraph_bridge/ngraph_encapsulate_impl.cc b/ngraph_bridge/ngraph_encapsulate_impl.cc index 81f0432ac..f7a8e9486 100644 --- a/ngraph_bridge/ngraph_encapsulate_impl.cc +++ b/ngraph_bridge/ngraph_encapsulate_impl.cc @@ -245,6 +245,7 @@ Status NGraphEncapsulateImpl::GetNgExecutable( Status NGraphEncapsulateImpl::AllocateNGInputTensors( const std::vector& tf_input_tensors, const std::shared_ptr& ng_exec, + const PipelinedTensorVector& 
inp_group_from_pipeline, ng::runtime::Backend* const op_backend, vector>& ng_inputs) { std::vector> input_copy_events; @@ -291,9 +292,10 @@ Status NGraphEncapsulateImpl::AllocateNGInputTensors( std::shared_ptr last_ng_tensor = input_caches[i].second; void* current_src_ptr = (void*)DMAHelper::base(&tf_input_tensors[i]); - std::shared_ptr current_ng_tensor = - GetCurrentNgTensor(current_src_ptr, last_src_ptr, last_ng_tensor, false, - ng_exec, op_backend, ng_element_type, ng_shape); + std::shared_ptr current_ng_tensor = GetCurrentNgTensor( + current_src_ptr, last_src_ptr, last_ng_tensor, false, ng_exec, + op_backend, ng_element_type, ng_shape, + m_executable_can_create_tensor ? inp_group_from_pipeline[i] : nullptr); bool is_cpu = m_op_backend_name == "CPU"; if (!is_cpu && current_ng_tensor->get_stale()) { @@ -340,6 +342,7 @@ Status NGraphEncapsulateImpl::AllocateNGInputTensors( Status NGraphEncapsulateImpl::AllocateNGOutputTensors( const std::vector& output_tensors, const std::shared_ptr& ng_exec, + const PipelinedTensorVector& out_group_from_pipeline, ng::runtime::Backend* const op_backend, vector>& ng_outputs) { std::vector>>& @@ -374,9 +377,10 @@ Status NGraphEncapsulateImpl::AllocateNGOutputTensors( NGRAPH_VLOG(4) << "NGraphEncapsulateImpl:: Output from non Variable Node"; #endif - current_ng_tensor = - GetCurrentNgTensor(current_dst_ptr, last_dst_ptr, last_ng_tensor, true, - ng_exec, op_backend, ng_element_type, ng_shape); + current_ng_tensor = GetCurrentNgTensor( + current_dst_ptr, last_dst_ptr, last_ng_tensor, true, ng_exec, + op_backend, ng_element_type, ng_shape, + m_executable_can_create_tensor ? out_group_from_pipeline[i] : nullptr); current_ng_tensor->set_stale(true); output_caches[i] = std::make_pair(current_dst_ptr, current_ng_tensor); @@ -393,7 +397,8 @@ std::shared_ptr NGraphEncapsulateImpl::GetCurrentNgTensor( const bool& output_tensor, const std::shared_ptr& ng_exec, ng::runtime::Backend* const op_backend, - const ng::element::Type& ng_element_type, const ng::Shape& ng_shape) { + const ng::element::Type& ng_element_type, const ng::Shape& ng_shape, + std::shared_ptr tensor_from_pipeline) { // NOTE: we assume that TF's pointers WILL change if it actually changes // values. 
ie, it will not reuse the same space if its rewritten it bool tf_tensor_has_changed = current_tf_ptr != last_tf_ptr; @@ -426,20 +431,72 @@ std::shared_ptr NGraphEncapsulateImpl::GetCurrentNgTensor( } // create a new ng tensor or use the last one std::shared_ptr current_ng_tensor; - if (need_new_tensor_creation) { - if (is_cpu) { - current_ng_tensor = - op_backend->create_tensor(ng_element_type, ng_shape, current_tf_ptr); + if (m_executable_can_create_tensor) { + current_ng_tensor = tensor_from_pipeline; + } else { + if (need_new_tensor_creation) { + if (is_cpu) { + current_ng_tensor = op_backend->create_tensor(ng_element_type, ng_shape, + current_tf_ptr); + } else { + current_ng_tensor = + op_backend->create_tensor(ng_element_type, ng_shape); + } } else { - current_ng_tensor = op_backend->create_tensor(ng_element_type, ng_shape); + current_ng_tensor = last_ng_tensor; } - } else { - current_ng_tensor = last_ng_tensor; } current_ng_tensor->set_stale(is_stale); return current_ng_tensor; } +Status NGraphEncapsulateImpl::CachePipelinedTensorIfNeeded( + std::shared_ptr ng_exec) { + if (!m_executable_can_create_tensor) { + return errors::Internal( + "CachePipelinedTensorIfNeeded called, but executable cannot create " + "tensors"); + } + auto itr = m_executable_pipelined_tensors_map.find(ng_exec); + if (itr == m_executable_pipelined_tensors_map.end()) { + // Create these pipelined ng tensors only if needed, else reuse from cache + size_t num_inputs = ng_exec->get_parameters().size(); + size_t num_outputs = ng_exec->get_results().size(); + PipelinedTensorMatrix pipelined_input_tensors(num_inputs); + PipelinedTensorMatrix pipelined_output_tensors(num_outputs); + for (size_t i = 0; i < num_inputs; i++) { + pipelined_input_tensors[i] = ng_exec->create_input_tensor(i, m_depth); + } + for (size_t i = 0; i < num_outputs; i++) { + pipelined_output_tensors[i] = ng_exec->create_output_tensor(i, m_depth); + } + m_executable_pipelined_tensors_map.insert( + {ng_exec, PipelinedTensorsStore(pipelined_input_tensors, + pipelined_output_tensors)}); + } + return Status::OK(); +} + +std::tuple +NGraphEncapsulateImpl::GetTensorsFromPipeline( + std::shared_ptr ng_exec) { + PipelinedTensorsStore pts = m_executable_pipelined_tensors_map.at(ng_exec); + + // TODO: do something about this spin lock + // get_tensors returns an index integer, that can be -1, 0, ... depth-1 + // If it returns -1, then it indicates there are no free groups of tensors + // or the pipeline is full. 
In that case, we need to wait, hence the while + std::tuple out_tpl; + while (true) { + out_tpl = pts.get_tensors(); + + if (std::get<0>(out_tpl) >= 0) { + break; + } + } + return out_tpl; +} + } // namespace ngraph_bridge } // namespace tensorflow \ No newline at end of file diff --git a/ngraph_bridge/ngraph_encapsulate_impl.h b/ngraph_bridge/ngraph_encapsulate_impl.h index a12fa1011..dcd758976 100644 --- a/ngraph_bridge/ngraph_encapsulate_impl.h +++ b/ngraph_bridge/ngraph_encapsulate_impl.h @@ -28,6 +28,7 @@ #include "logging/ngraph_log.h" #include "ngraph_bridge/ngraph_freshness_tracker.h" +#include "ngraph_bridge/ngraph_pipelined_tensors.h" namespace tensorflow { @@ -61,6 +62,7 @@ class NGraphEncapsulateImpl { Status AllocateNGInputTensors( const std::vector& tf_input_tensors, const std::shared_ptr& ng_exec, + const PipelinedTensorVector& inp_group_from_pipeline, ng::runtime::Backend* const op_backend, vector>& ng_inputs); @@ -69,6 +71,7 @@ class NGraphEncapsulateImpl { Status AllocateNGOutputTensors( const std::vector& tf_output_tensors, const std::shared_ptr& ng_exec, + const PipelinedTensorVector& out_group_from_pipeline, ng::runtime::Backend* const op_backend, vector>& ng_outputs); @@ -79,7 +82,8 @@ class NGraphEncapsulateImpl { const bool& output_tensor, const std::shared_ptr& ng_exec, ng::runtime::Backend* const op_backend, - const ng::element::Type& ng_element_type, const ng::Shape& ng_shape); + const ng::element::Type& ng_element_type, const ng::Shape& ng_shape, + std::shared_ptr tensor_from_pipeline); // Accessors(getters and setters) for the private data members of // NgraphEncapsulateImpl class @@ -185,6 +189,25 @@ class NGraphEncapsulateImpl { void SetName(string name) { m_name = name; } + void SetExecCanCreateTensor(bool b) { m_executable_can_create_tensor = b; } + + bool GetExecCanCreateTensor() { return m_executable_can_create_tensor; } + + void ClearNgExecPipelinedTensorMap() { + m_executable_pipelined_tensors_map.clear(); + } + + Status CachePipelinedTensorIfNeeded( + std::shared_ptr ng_exec); + + std::tuple + GetTensorsFromPipeline(std::shared_ptr ng_exec); + + void ReturnPipelinedTensors( + std::shared_ptr ng_exec, size_t idx) { + m_executable_pipelined_tensors_map.at(ng_exec).return_tensors(idx); + } + // TF Graph for the cluster Graph m_graph; @@ -219,6 +242,13 @@ class NGraphEncapsulateImpl { // A single instance of freshness_tracker is used across all // nGraphEncapsulateOp and nGraphVariable op NGraphFreshnessTracker* m_freshness_tracker; + + bool m_executable_can_create_tensor = false; + std::unordered_map, + PipelinedTensorsStore> + m_executable_pipelined_tensors_map; + + int m_depth{2}; // TODO make this settable }; } // namespace ngraph_bridge diff --git a/ngraph_bridge/ngraph_encapsulate_op.cc b/ngraph_bridge/ngraph_encapsulate_op.cc index 695a1f93c..640e1cd3b 100644 --- a/ngraph_bridge/ngraph_encapsulate_op.cc +++ b/ngraph_bridge/ngraph_encapsulate_op.cc @@ -44,6 +44,7 @@ #include "ngraph_bridge/ngraph_encapsulate_op.h" #include "ngraph_bridge/ngraph_freshness_tracker.h" #include "ngraph_bridge/ngraph_mark_for_clustering.h" +#include "ngraph_bridge/ngraph_pipelined_tensors.h" #include "ngraph_bridge/ngraph_timer.h" #include "ngraph_bridge/ngraph_utils.h" @@ -203,6 +204,13 @@ NGraphEncapsulateOp::NGraphEncapsulateOp(OpKernelConstruction* ctx) BackendManager::SetConfig(ng_encap_impl.GetOpBackend(), additional_attribute_map); + ng_encap_impl.SetExecCanCreateTensor( + BackendManager::GetBackend(ng_encap_impl.GetOpBackend()) + 
->executable_can_create_tensors()); + NGRAPH_VLOG(5) << "Executable can " + << (ng_encap_impl.GetExecCanCreateTensor() ? "" : "not") + << " create tensors"; + event.Stop(); ngraph::Event::write_trace(event); } @@ -261,6 +269,7 @@ NGraphEncapsulateOp::~NGraphEncapsulateOp() { ng_encap_impl.ClearNgExecOutputCache(); ng_encap_impl.ClearNgExecMap(); ng_encap_impl.ClearNgFunctionMap(); + ng_encap_impl.ClearNgExecPipelinedTensorMap(); // Release the backend NGRAPH_VLOG(2) << "~NGraphEncapsulateOp():: ReleaseBackend"; @@ -319,6 +328,31 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) { Timer create_or_lookup_tensors; + int pipeline_idx = -1; + PipelinedTensorVector inp_group_from_pipeline; + PipelinedTensorVector out_group_from_pipeline; + if (ng_encap_impl.GetExecCanCreateTensor()) { + OP_REQUIRES_OK(ctx, ng_encap_impl.CachePipelinedTensorIfNeeded(ng_exec)); + // Cache must contain the ng_exec at this point + + try { + std::tie(pipeline_idx, inp_group_from_pipeline, out_group_from_pipeline) = + ng_encap_impl.GetTensorsFromPipeline(ng_exec); + } catch (const std::exception& exp) { + OP_REQUIRES( + ctx, false, + errors::Internal("Caught exception while getting pipelined tensors: ", + exp.what(), "\n")); + } + + if (pipeline_idx < 0) { + OP_REQUIRES(ctx, false, + errors::Internal("Expected GetTensorsFromPipeline to return " + "an index >= 0, but got ", + pipeline_idx)); + } + } + if (ng_encap_impl.GetNgraphFreshnessTracker() == nullptr) { auto creator = [](NGraphFreshnessTracker** tracker) { *tracker = new NGraphFreshnessTracker(); @@ -343,7 +377,8 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) { int ng_input_tensor_size_in_bytes = 0; OP_REQUIRES_OK(ctx, ng_encap_impl.AllocateNGInputTensors( - tf_input_tensors, ng_exec, op_backend, ng_inputs)); + tf_input_tensors, ng_exec, inp_group_from_pipeline, + op_backend, ng_inputs)); event_alloc_input.Stop(); @@ -384,7 +419,8 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) { } OP_REQUIRES_OK(ctx, ng_encap_impl.AllocateNGOutputTensors( - tf_output_tensors, ng_exec, op_backend, ng_outputs)); + tf_output_tensors, ng_exec, out_group_from_pipeline, + op_backend, ng_outputs)); auto output_caches = ng_encap_impl.GetNgExecOutputCacheMap(ng_exec); event_alloc_output.Stop(); @@ -659,6 +695,17 @@ void NGraphEncapsulateOp::Compute(OpKernelContext* ctx) { int time_copy_output_tensors_to_host = copy_output_tensors_to_host.ElapsedInMS(); + if (ng_encap_impl.GetExecCanCreateTensor()) { + try { + ng_encap_impl.ReturnPipelinedTensors(ng_exec, pipeline_idx); + } catch (const std::exception& exp) { + OP_REQUIRES(ctx, false, + errors::Internal( + "Caught exception while returning pipelined tensors: ", + exp.what(), "\n")); + } + } + NGRAPH_VLOG(4) << "NGraphEncapsulateOp::Compute done marking fresh for cluster " << ng_encap_impl.GetNgraphCluster(); diff --git a/ngraph_bridge/ngraph_pipelined_tensors.cc b/ngraph_bridge/ngraph_pipelined_tensors.cc new file mode 100644 index 000000000..573a1b120 --- /dev/null +++ b/ngraph_bridge/ngraph_pipelined_tensors.cc @@ -0,0 +1,141 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "ngraph_bridge/ngraph_pipelined_tensors.h" + +using namespace std; +namespace ng = ngraph; + +namespace tensorflow { + +namespace ngraph_bridge { + +IndexLibrary::IndexLibrary(size_t depth) : m_depth(depth) { + for (size_t i = 0; i < depth; i++) { + m_free_depth_indexes.insert(i); + } +} + +void IndexLibrary::return_index(size_t id) { + if (m_depth == 0) { + throw std::runtime_error( + "Depth=0, so no one should be calling return_index"); + } + if (id > m_depth - 1) { + throw std::runtime_error("Depth = " + to_string(m_depth) + + " but passed an index to return ( = " + + to_string(id) + "), which is too large"); + } + if (is_free(id)) { + throw std::runtime_error( + "Attempted to return index " + to_string(id) + + " but it is already present in the free indices set"); + } + + insert_to_free_set(id); +} + +int IndexLibrary::get_index() { + if (m_depth == 0) { + return -1; + } else { + std::lock_guard lock(m_mtx); + if (m_free_depth_indexes.size() == 0) { + return -1; + } else { + // Find and return the smallest free integer + int min_idx = m_depth; // nothing can be >= m_depth + for (auto i : m_free_depth_indexes) { + if (i < min_idx) { + min_idx = i; + } + } + if (min_idx >= m_depth || min_idx < 0) { + throw std::runtime_error( + "get_index can only return values between 0 to depth-1 from here, " + "but attempted to return " + + to_string(min_idx)); + } + m_free_depth_indexes.erase(min_idx); + return min_idx; + } + } +} + +void IndexLibrary::insert_to_free_set(size_t id) { + std::lock_guard lock(m_mtx); + m_free_depth_indexes.insert(id); +} + +bool IndexLibrary::is_free(size_t id) { + std::lock_guard lock(m_mtx); + if (id > m_depth - 1) { + throw std::runtime_error("Asked to check if id=" + to_string(id) + + " is free, but depth=" + to_string(m_depth)); + } + return m_free_depth_indexes.find(id) != m_free_depth_indexes.end(); +} + +PipelinedTensorsStore::PipelinedTensorsStore(PipelinedTensorMatrix in, + PipelinedTensorMatrix out) + : m_in_tensors(in), + m_out_tensors(out), + m_num_inputs(in.size()), + m_num_outputs(out.size()) { + // The executable could have no inputs or no outputs. + // Hence the if-else below to determine m_depth + bool has_inputs = in.size() > 0; + bool has_outputs = out.size() > 0; + if (has_inputs) { + if (has_outputs) { + auto m_depth_in = in[0].size(); + auto m_depth_out = out[0].size(); + // We assume that input and output depths are same + m_depth = std::min(m_depth_in, m_depth_out); + } else { + m_depth = in[0].size(); + } + } else { + if (has_outputs) { + m_depth = out[0].size(); + } else { + m_depth = 0; // The executable has no inputs and outputs + } + } + idx_lib = make_shared(m_depth); +} + +tuple +PipelinedTensorsStore::get_tensors() { + int i = idx_lib->get_index(); + return make_tuple(i, (i < 0 ? PipelinedTensorVector{} : get_group(true, i)), + (i < 0 ? 
PipelinedTensorVector{} : get_group(false, i))); +} + +void PipelinedTensorsStore::return_tensors(size_t id) { + idx_lib->return_index(id); +} + +PipelinedTensorVector PipelinedTensorsStore::get_group(bool is_input, + size_t i) { + PipelinedTensorVector group; + for (size_t idx = 0; idx < (is_input ? m_num_inputs : m_num_outputs); idx++) { + group.push_back((is_input ? m_in_tensors : m_out_tensors)[idx][i]); + } + return group; +} +} +} diff --git a/ngraph_bridge/ngraph_pipelined_tensors.h b/ngraph_bridge/ngraph_pipelined_tensors.h new file mode 100644 index 000000000..515cac25e --- /dev/null +++ b/ngraph_bridge/ngraph_pipelined_tensors.h @@ -0,0 +1,145 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef NGRAPH_TF_BRIDGE_PIPELINED_TENSORS_H_ +#define NGRAPH_TF_BRIDGE_PIPELINED_TENSORS_H_ +#pragma once + +#include "ngraph/event_tracing.hpp" +#include "ngraph/runtime/backend.hpp" + +// Consider an ng executable, which has a inputs and b outputs. Let d_input[i] +// be the depth of the pipeline for input i. Similarly d_output[j] is the depth +// of the pipeline for output j. + +// Simplifying assumptions about pipeline depths: for all 0 <= i < a, 0 <= j < +// b, d_input[i] == d_output[i] == d. Most likely, d = 2 + +// Pipelined tensors Matrix: When the executable is used to create tensors, it +// will +// create non-ragged matrices of a x d and b x d tensors. + +// Input Group m: A set of a input tensors that can be used to feed data to the +// executable. This represents the m'th column of the input pipelined tensor +// matrix defined above + +// Output Group n: A set of b input tensors that can be used to collect data +// from the executable. This represents the n'th column of the input pipelined +// tensor matrix defined above + +// Simplifying assumption: We assume m == n, that is we use the same pipeline +// depth index when using call() on an executable. Because of this assumption we +// can store the input and output pipelined tensor matrix in the same class +// object. If we decide we can relax this constraint, then we can split up this +// class into 2, one handling inputs, one for outputs. + +// To implement the above design, we use the class PipelinedTensorsStore. +// It acts as a store for 2 PipelinedTensorMatrix (input and output) and +// supports 2 public functions get_tensors and return_tensors +// get_tensors: get_tensors is used to get an index (representing the pipeline +// depth) +// and 2 PipelinedTensorVector (for inputs and outputs). 
+// Note that get_tensors can return -1 as the index to indicate that +// no tensors are available at the moment +// return_tensors: Once we are done using it, we call return_tensors +// with the checked out index from get_tensors to indicate to +// PipelinedTensorsStore +// that we are done using the tensors of that pipeline depth, +// and it can give it to other threads that request tensors. + +// PipelinedTensorsStore relies on IndexLibrary to be threadsafe. +// IndexLibrary manages a set of integers: 0,1,...depth-1 +// It supports 2 functions get_index and return_index +// get_index returns the smallest int from the set of free indices +// (it returns -1 if none are available) +// return_index accepts back a number that was checkedout earlier +// IndexLibrary can be used safely in a multithreaded scenario since +// the underlying store of free indices is locked by a mutex + +using namespace std; +namespace ng = ngraph; + +namespace tensorflow { + +namespace ngraph_bridge { + +typedef vector> PipelinedTensorVector; +typedef vector PipelinedTensorMatrix; + +// IndexLibrary is a class that accepts an unsigned int "depth". This means that +// this class now owns integers from 0, 1, 2, ... depth-1 + +// See sample usage in test/test_index_library.cpp +class IndexLibrary { + public: + IndexLibrary(size_t depth); + + // If available return the smallest free integer (0<=i m_free_depth_indexes; + size_t m_depth; + std::mutex m_mtx; // protects m_free_depth_indexes + + // insert id in m_free_depth_indexes + void insert_to_free_set(size_t id); + // check if id already exists in m_free_depth_indexes + bool is_free(size_t id); +}; + +class PipelinedTensorsStore { + public: + PipelinedTensorsStore(PipelinedTensorMatrix in, PipelinedTensorMatrix out); + + // returns a tuple of idx, and 2 vectors of ng tensors (input and output + // groups). If the idx is negative, then its an invalid group (because + // pipeline is filled right now) + tuple get_tensors(); + + // Return an integer that was checked out by get_tensors. 
+ // This indicates that the tensors corresponding to depth=id in the pipeline + // are ready for reuse and can be returned when get_tensors is called again + void return_tensors(size_t id); + + private: + PipelinedTensorMatrix m_in_tensors; + PipelinedTensorMatrix m_out_tensors; + size_t m_depth; + size_t m_num_inputs; + size_t m_num_outputs; + shared_ptr idx_lib; + + // Get the i'th depth tensors for inputs if is_input is true, else for outputs + PipelinedTensorVector get_group(bool is_input, size_t i); +}; +} +} + +#endif // NGRAPH_TF_BRIDGE_PIPELINED_TENSORS_H_ diff --git a/ngraph_bridge/version.cc b/ngraph_bridge/version.cc index 5e0a62b28..335a2384d 100644 --- a/ngraph_bridge/version.cc +++ b/ngraph_bridge/version.cc @@ -24,7 +24,7 @@ // nGraph-TensorFlow bridge uses semantic versioning: see http://semver.org/ #define NG_TF_MAJOR_VERSION 0 -#define NG_TF_MINOR_VERSION 18 +#define NG_TF_MINOR_VERSION 19 #define NG_TF_PATCH_VERSION 0 // The version suffix is used for pre-release version numbers @@ -32,7 +32,7 @@ // candidate such as v0.7.0-rc0 // The code in master will always have the last released version number // with a suffix of '-master' -#define NG_TF_VERSION_SUFFIX "-rc4" +#define NG_TF_VERSION_SUFFIX "-rc0" #define VERSION_STR_HELPER(x) #x #define VERSION_STR(x) VERSION_STR_HELPER(x) diff --git a/python/setup.in.py b/python/setup.in.py index d9d03551f..a0a48ec68 100644 --- a/python/setup.in.py +++ b/python/setup.in.py @@ -59,7 +59,7 @@ def get_tag(self): setup( name='ngraph_tensorflow_bridge', - version='0.18.0rc4', + version='0.19.0rc0', description='Intel nGraph compiler and runtime for TensorFlow', long_description=long_description, long_description_content_type="text/markdown", diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9770cd8d6..c4b5c80a0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -52,6 +52,7 @@ set(SRC graph_rewrites/encapsulate_clusters_test.cc graph_rewrites/disable_ops_test.cc graph_rewrites/mark_for_clustering_test.cc + test_index_library.cpp test_utilities.cpp test_math_ops.cpp test_nn_ops.cpp diff --git a/test/encapsulate_op/encapsulate_op_test.cc b/test/encapsulate_op/encapsulate_op_test.cc index f8659a2d8..7cb550d8b 100644 --- a/test/encapsulate_op/encapsulate_op_test.cc +++ b/test/encapsulate_op/encapsulate_op_test.cc @@ -146,7 +146,7 @@ TEST(EncapsulateOp, AllocateNGInputTensors) { std::vector> ng_inputs; - ASSERT_OK(ng_encap_impl.AllocateNGInputTensors(input_tensors, ng_exec, + ASSERT_OK(ng_encap_impl.AllocateNGInputTensors(input_tensors, ng_exec, {}, op_backend, ng_inputs)); BackendManager::ReleaseBackend("CPU"); } @@ -187,7 +187,7 @@ TEST(EncapsulateOp, AllocateNGOutputTensors) { } std::vector> ng_outputs; - ASSERT_OK(ng_encap_impl.AllocateNGOutputTensors(output_tensors, ng_exec, + ASSERT_OK(ng_encap_impl.AllocateNGOutputTensors(output_tensors, ng_exec, {}, op_backend, ng_outputs)); BackendManager::ReleaseBackend("CPU"); diff --git a/test/test_index_library.cpp b/test/test_index_library.cpp new file mode 100644 index 000000000..40e8d62aa --- /dev/null +++ b/test/test_index_library.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* + * Copyright 2019 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include +#include +#include + +#include "gtest/gtest.h" + +#include "ngraph_bridge/ngraph_pipelined_tensors.h" + +using namespace std; +namespace ng = ngraph; + +namespace tensorflow { + +namespace ngraph_bridge { + +namespace testing { + +TEST(IndexLibrary, SingleThreadTest1) { + IndexLibrary idx_lib{3}; + // idx_lib contains {0, 1, 2}; + + int i0 = idx_lib.get_index(); + ASSERT_EQ(i0, 0); + // idx_lib contains {1, 2}; i0 = 0 checked out + + int i1 = idx_lib.get_index(); + ASSERT_EQ(i1, 1); + // idx_lib contains {2}; i0 = 0, i1 = 1 checked out + + idx_lib.return_index(i0); + // idx_lib contains {0, 2}; i1 = 1 checked out + + int i2 = idx_lib.get_index(); + ASSERT_EQ(i2, 0); + // idx_lib contains {2}; i1 = 1, i2 = 0 checked out + + int i3 = idx_lib.get_index(); + ASSERT_EQ(i3, 2); + // idx_lib contains {}; i1 = 1, i2 = 0, i3 = 2 checked out + + int i4 = idx_lib.get_index(); + ASSERT_EQ(i4, -1) + << "Expected index library to be empty, hence get_index should return -1"; + + // Try to return an invalid index + ASSERT_THROW(idx_lib.return_index(50), std::runtime_error); + + idx_lib.return_index(i1); + // idx_lib contains {1}; i2 = 0, i3 = 2 checked out + + // Try to return an index that is already checkedin/returned + ASSERT_THROW(idx_lib.return_index(i1), std::runtime_error); +} + +TEST(IndexLibrary, SingleThreadTest2) { + IndexLibrary idx_lib{0}; + + // Since it is an empty library it will always return -1 + ASSERT_EQ(idx_lib.get_index(), -1); +} + +// 2 threads run randomly and attempt to get and return indices from the same +// IndexLibrary 10 times. +// The test asserts if one of the threads managed to get an index i, then the +// current and other thread must not have that index i +TEST(IndexLibrary, MultiThreadTest) { + IndexLibrary idx_lib{5}; + + auto seed = static_cast(time(0)); + std::mt19937 gen(seed); + std::uniform_real_distribution<> dis(0, 1); + + vector>> checked_out_collections = { + make_shared>(), make_shared>()}; + + auto worker = [&idx_lib, &dis, &gen, &seed, + &checked_out_collections](size_t thread_id) { + shared_ptr> my_checked_out = checked_out_collections[thread_id]; + shared_ptr> other_checked_out = + checked_out_collections[1 - thread_id]; + int count_work = 0; + while (true) { + if (dis(gen) > 0.5) { + int i = idx_lib.get_index(); + if (i >= 0) { + ASSERT_TRUE(my_checked_out->find(i) == my_checked_out->end()) + << "Failure seed: " << seed; + my_checked_out->insert(i); + count_work++; + // No need to lock access to my_checked_out and other_checked_out + // There is an implicit lock in between them from idx_lib + ASSERT_TRUE(other_checked_out->find(i) == other_checked_out->end()) + << "Failure seed: " << seed << "\n"; + } + } else { + if (my_checked_out->begin() != my_checked_out->end()) { + int j = *(my_checked_out->begin()); + + idx_lib.return_index(j); + count_work++; + my_checked_out->erase(j); + } + } + // wait for 1 or 2 ms randomly + std::chrono::milliseconds timespan((dis(gen) > 0.5) ? 
1 : 2); + std::this_thread::sleep_for(timespan); + if (count_work >= 10) { + break; + } + } + // In the end return all indices + while (my_checked_out->begin() != my_checked_out->end()) { + int j = *(my_checked_out->begin()); + idx_lib.return_index(j); + my_checked_out->erase(j); + } + }; + + std::thread thread0(worker, 0); + std::thread thread1(worker, 1); + + thread0.join(); + thread1.join(); +} +} +} +} diff --git a/test/tf_exec.cpp b/test/tf_exec.cpp index e145aa066..9cea111de 100644 --- a/test/tf_exec.cpp +++ b/test/tf_exec.cpp @@ -30,6 +30,7 @@ #include "ngraph_bridge/ngraph_builder.h" #include "ngraph_bridge/ngraph_utils.h" +#include "ngraph_bridge/version.h" #include "test/test_utilities.h" using namespace std; @@ -42,6 +43,101 @@ namespace testing { #define ASSERT_OK(x) ASSERT_EQ((x), ::tensorflow::Status::OK()); +Status LoadGraph(const string& graph_file_name, + std::unique_ptr* session, + const tensorflow::SessionOptions& options) { + tensorflow::GraphDef graph_def; + auto load_graph_status = + ReadTextProto(Env::Default(), graph_file_name, &graph_def); + if (!load_graph_status.ok()) { + return tensorflow::errors::NotFound("Failed to load compute graph at '", + graph_file_name, "'"); + } + session->reset(tensorflow::NewSession(options)); + Status session_create_status = (*session)->Create(graph_def); + if (!session_create_status.ok()) { + return session_create_status; + } + return Status::OK(); +} + +Status CreateSession(const string& graph_filename, const string& backend_name, + unique_ptr& session) { + SessionOptions options; + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(OptimizerOptions_Level_L0); + options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->set_constant_folding(RewriterConfig::OFF); + + if (ngraph_tf_is_grappler_enabled()) { + auto* custom_config = options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->add_custom_optimizers(); + + custom_config->set_name("ngraph-optimizer"); + (*custom_config->mutable_parameter_map())["ngraph_backend"].set_s( + backend_name); + (*custom_config->mutable_parameter_map())["device_id"].set_s("0"); + + options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->set_min_graph_nodes(-1); + + options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->set_meta_optimizer_iterations(RewriterConfig::ONE); + } + + // Load the network + Status load_graph_status = LoadGraph(graph_filename, &session, options); + return load_graph_status; +} + +TEST(tf_exec, SingleGraphOn2Threads) { + string graph_name = "test_axpy.pbtxt"; + vector backends{"CPU", "INTERPRETER"}; + for (auto be : backends) { + unique_ptr session; + ASSERT_OK(CreateSession(graph_name, be, session)); + + auto worker = [&session](size_t thread_id) { + string inp_tensor_name_0{"x"}; + string inp_tensor_name_1{"y"}; + string out_tensor_name{"add"}; + std::vector out_tensor_vals; + + for (int i = 0; i < 10; i++) { + Tensor inp_tensor_val(tensorflow::DT_FLOAT, + tensorflow::TensorShape({2, 3})); + vector in_vals(6, float(i)); + AssignInputValues(inp_tensor_val, in_vals); + Tensor out_tensor_expected_val(tensorflow::DT_FLOAT, + tensorflow::TensorShape({2, 3})); + vector out_vals(6, 6.0 * float(i)); + AssignInputValues(out_tensor_expected_val, out_vals); + + std::vector> inputs = { + {inp_tensor_name_0, inp_tensor_val}, + {inp_tensor_name_1, inp_tensor_val}}; + + NGRAPH_VLOG(5) << "thread_id: " << thread_id << " started: " << i; + ASSERT_OK( + session->Run(inputs, {out_tensor_name}, 
{}, &out_tensor_vals)); + NGRAPH_VLOG(5) << "thread_id: " << thread_id << " finished: " << i; + Compare(out_tensor_vals, {out_tensor_expected_val}); + } + }; + + std::thread thread0(worker, 0); + std::thread thread1(worker, 1); + + thread0.join(); + thread1.join(); + } +} + TEST(tf_exec, hello_world) { Scope root = Scope::NewRootScope(); @@ -416,7 +512,7 @@ TEST(tf_exec, DISABLED_Op_L2Loss) { Scope root_ngraph = root.NewSubScope("sub_scope_ngraph"); root_ngraph = root_ngraph.WithDevice("/device:NGRAPH:0"); - std::vector > input_sizes; + std::vector> input_sizes; input_sizes.push_back({2, 3, 4}); input_sizes.push_back({0}); @@ -447,7 +543,7 @@ TEST(tf_exec, DISABLED_Op_Unpack) { root = root.WithDevice("/device:CPU:0"); root_ngraph = root_ngraph.WithDevice("/device:NGRAPH:0"); - std::vector > input_sizes; + std::vector> input_sizes; int input_rank = 3;
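
Usage sketch (illustrative only, not part of the applied patch): the new PipelinedTensorsStore / IndexLibrary pair added above is easiest to read as a checkout/return protocol around Executable::call. In the real op this protocol is driven by NGraphEncapsulateImpl::CachePipelinedTensorIfNeeded, GetTensorsFromPipeline and ReturnPipelinedTensors; the sketch below shows the same flow in isolation, assuming the header introduced by this patch. The helper name RunOneIteration is hypothetical, and the busy-wait mirrors the spin loop that GetTensorsFromPipeline currently uses (flagged with a TODO in ngraph_encapsulate_impl.cc).

#include <memory>
#include <tuple>

#include "ngraph_bridge/ngraph_pipelined_tensors.h"

namespace tensorflow {
namespace ngraph_bridge {

// Hypothetical caller-side helper: check out one pipeline slot, run the
// executable on that slot's tensor groups, then return the slot.
void RunOneIteration(PipelinedTensorsStore& store,
                     std::shared_ptr<ngraph::runtime::Executable> ng_exec) {
  int idx = -1;
  PipelinedTensorVector inputs;
  PipelinedTensorVector outputs;

  // get_tensors() hands back the smallest free pipeline index plus the input
  // and output tensor groups for that index; -1 means every group is checked
  // out, so we wait (same spin-wait the op currently performs).
  do {
    std::tie(idx, inputs, outputs) = store.get_tensors();
  } while (idx < 0);

  // ... copy TF input data into `inputs` here ...

  // Execute with this pipeline slot's tensors (nGraph's Executable::call
  // takes outputs first, then inputs).
  ng_exec->call(outputs, inputs);

  // ... copy `outputs` back into TF output tensors here ...

  // Hand the slot back so another thread can reuse pipeline index `idx`.
  store.return_tensors(idx);
}

}  // namespace ngraph_bridge
}  // namespace tensorflow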