Skip to content

Commit

Permalink
Sindhu/more ops (scatternd) (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
sayantan-nervana authored Nov 19, 2019
1 parent 50fc3c9 commit bd7a76a
Show file tree
Hide file tree
Showing 10 changed files with 337 additions and 11 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ if (NOT USE_PRE_BUILT_NGRAPH)
ExternalProject_Add(
ext_ngraph
GIT_REPOSITORY https://github.com/NervanaSystems/ngraph
GIT_TAG v0.25.1-rc.9
GIT_TAG v0.25.1-rc.10
CMAKE_ARGS
-DNGRAPH_DISTRIBUTED_ENABLE=${NGRAPH_DISTRIBUTED_ENABLE}
-DNGRAPH_INSTALL_PREFIX=${NGRAPH_ARTIFACTS_DIR}
Expand Down
8 changes: 4 additions & 4 deletions bazel/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ tf_workspace(path_prefix = "", tf_repo_name = "org_tensorflow")
http_archive(
name = "ngraph",
build_file = "//:bazel/ngraph.BUILD",
sha256 = "bb9effe71818d932dfa1f024051ba716267ca1b2fc35471faf780f7bfdf6c9b3",
strip_prefix = "ngraph-0.25.1-rc.9",
sha256 = "110b4f40033ea28425a3b80bc0cb81a0301f5cc0bb92c200f113dfeb9533aec7",
strip_prefix = "ngraph-0.25.1-rc.10",
urls = [
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.9.tar.gz",
"https://github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.9.tar.gz"
"https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.10.tar.gz",
"https://github.com/NervanaSystems/ngraph/archive/v0.25.1-rc.10.tar.gz"
],
)

Expand Down
6 changes: 3 additions & 3 deletions bazel/ngraph.BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.9\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.10\\"',
"-D NGRAPH_DEX_ONLY",
'-D PROJECT_ROOT_DIR=\\"\\"',
'-D NGRAPH_STATIC_LIB_ENABLE'
Expand Down Expand Up @@ -118,7 +118,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.9\\"',
'-D NGRAPH_VERSION=\\"v0.25.1-rc.10\\"',
"-D NGRAPH_DEX_ONLY",
'-D PROJECT_ROOT_DIR=\\"\\"',
] + CXX_ABI,
Expand Down Expand Up @@ -269,7 +269,7 @@ cc_library(
"-fstack-protector-all",
'-D SHARED_LIB_PREFIX=\\"lib\\"',
'-D SHARED_LIB_SUFFIX=\\".so\\"',
'-D NGRAPH_VERSION=\\"0.25.1-rc.9\\"',
'-D NGRAPH_VERSION=\\"0.25.1-rc.10\\"',
"-D NGRAPH_DEX_ONLY",
'-D PROJECT_ROOT_DIR=\\"\\"',
'-D NGRAPH_CPU_STATIC_LIB_ENABLE'
Expand Down
2 changes: 1 addition & 1 deletion build_ngtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def main():
'''

# Component versions
ngraph_version = "v0.25.1-rc.9"
ngraph_version = "v0.25.1-rc.10"
tf_version = "v1.14.0"

# Command line parser options
Expand Down
33 changes: 31 additions & 2 deletions ngraph_bridge/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3742,6 +3742,34 @@ static Status TranslateRsqrtOp(
});
}

static Status TranslateScatterNdOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
shared_ptr<ng::Node> ng_indices;
shared_ptr<ng::Node> ng_updates;
TF_RETURN_IF_ERROR(
GetInputNodes(ng_op_map, op, &ng_indices, &ng_updates, nullptr));

std::vector<int> ng_shape;
TF_RETURN_IF_ERROR(GetStaticInputVector(op, 2, static_input_map, &ng_shape));
// Copy the int vector to a size_t vector, because that is what ng::Shape
// accepts
std::vector<size_t> ng_shape_size_t(ng_shape.begin(), ng_shape.end());

// Create a tensor and populate the tensor with "0" to Add to ScatterNd
auto et = ng_updates->get_element_type();
std::vector<std::string> constant_values(ng::shape_size(ng_shape_size_t),
"0");
auto ng_inputs = ConstructNgNode<ng::op::Constant>(
op->name(), et, ng::Shape(ng_shape_size_t), constant_values);

SaveNgOp(ng_op_map, op->name(),
ConstructNgNode<ng::op::ScatterNDAdd>(op->name(), ng_inputs,
ng_indices, ng_updates));

return Status::OK();
}

static Status TranslateRsqrtGradOp(const Node* op,
const std::vector<const Tensor*>&,
Builder::OpMap& ng_op_map) {
Expand Down Expand Up @@ -5025,8 +5053,9 @@ const static std::map<
{"ReluGrad", TranslateReluGradOp}, {"Reshape", TranslateReshapeOp},
{"ResizeBilinear", TranslateResizeBilinearOp},
{"Rsqrt", TranslateRsqrtOp}, {"RsqrtGrad", TranslateRsqrtGradOp},
{"Select", TranslateSelectOp}, {"Shape", TranslateShapeOp},
{"Sigmoid", TranslateSigmoidOp}, {"SigmoidGrad", TranslateSigmoidGradOp},
{"ScatterNd", TranslateScatterNdOp}, {"Select", TranslateSelectOp},
{"Shape", TranslateShapeOp}, {"Sigmoid", TranslateSigmoidOp},
{"SigmoidGrad", TranslateSigmoidGradOp},
{"Sin", TranslateUnaryOp<ngraph::op::Sin>}, {"Size", TranslateSizeOp},
{"Sign", TranslateUnaryOp<ngraph::op::Sign>}, {"Slice", TranslateSliceOp},
{"Snapshot", TranslateIdentityOp}, {"Softmax", TranslateSoftmaxOp},
Expand Down
4 changes: 4 additions & 0 deletions ngraph_bridge/ngraph_mark_for_clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ Status MarkForClustering(Graph* graph, const std::set<string> skip_these_nodes,
SimpleConfirmationFunction();
confirmation_function_map["Rsqrt"] = SimpleConfirmationFunction();
confirmation_function_map["RsqrtGrad"] = SimpleConfirmationFunction();
confirmation_function_map["ScatterNd"] = SimpleConfirmationFunction();
confirmation_function_map["Select"] = SimpleConfirmationFunction();
confirmation_function_map["Shape"] = SimpleConfirmationFunction();
confirmation_function_map["Sigmoid"] = SimpleConfirmationFunction();
Expand Down Expand Up @@ -568,6 +569,8 @@ Status MarkForClustering(Graph* graph, const std::set<string> skip_these_nodes,
type_constraint_map["ResizeBilinear"]["T"] = NGraphNumericDTypes();
type_constraint_map["Rsqrt"]["T"] = NGraphDTypes();
type_constraint_map["RsqrtGrad"]["T"] = NGraphRealDTypes();
type_constraint_map["ScatterNd"]["T"] = NGraphDTypes();
type_constraint_map["ScatterNd"]["Tindices"] = NGraphIndexDTypes();
type_constraint_map["Select"]["T"] = NGraphDTypes();
type_constraint_map["Shape"]["T"] = NGraphDTypes();
type_constraint_map["Shape"]["out_type"] = NGraphIndexDTypes();
Expand Down Expand Up @@ -651,6 +654,7 @@ Status MarkForClustering(Graph* graph, const std::set<string> skip_these_nodes,
return Status::OK();
};
set_attributes_map["Reshape"] = SetStaticInputs({1});
set_attributes_map["ScatterNd"] = SetStaticInputs({2});
set_attributes_map["Slice"] = SetStaticInputs({1, 2});
set_attributes_map["Split"] = SetStaticInputs({0});
set_attributes_map["SplitV"] = SetStaticInputs({1, 2});
Expand Down
70 changes: 70 additions & 0 deletions test/python/tensorflow/python_tests_list.txt
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,76 @@ relu_op_test.ReluTest.testGradientFloat32
relu_op_test.ReluTest.testGradientFloat64
relu_op_test.ReluTest.testGradientScalar

scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testBool
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape1
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape3
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testExtraIndicesDimensions
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank2ElementUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank2SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank3SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank7SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testInvalidShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3InvalidShape1
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3InvalidShape2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3ValidShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testScatterNdRepatedIndicesAdd
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch1DSliceDim2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch1DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch2DSliceDim2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch2DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testString
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedIndicesShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedOutputShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedUpdatesShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.test_session
scatter_nd_ops_test.ScatterNdTensorTest.testTensorScatterUpdateWithForwarding
scatter_nd_ops_test.ScatterNdTensorTest.testUpdateAddSub
scatter_nd_ops_test.ScatterNdTensorTest.testUpdateAddSubGradients
scatter_nd_ops_test.ScatterNdTensorTest.test_session
# This test fails on CPU backend on char (and maybe other data types)
# scatter_nd_ops_test.ScatterNdTest.testBool
scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape1
# TODO: Test failing due to incorrect error message
# scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape2
scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape3
# This test fails on CPU backend on some unsupported data types
# scatter_nd_ops_test.ScatterNdTest.testExtraIndicesDimensions
scatter_nd_ops_test.ScatterNdTest.testGradientsRank2ElementUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank2SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank3SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank7SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testInvalidShape
scatter_nd_ops_test.ScatterNdTest.testRank3InvalidShape1
scatter_nd_ops_test.ScatterNdTest.testRank3InvalidShape2
scatter_nd_ops_test.ScatterNdTest.testRank3ValidShape
scatter_nd_ops_test.ScatterNdTest.testScatterNdRepatedIndicesAdd
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch1DSliceDim2
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch1DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch2DSliceDim2
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch2DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdTest.testString
scatter_nd_ops_test.ScatterNdTest.testUndefinedIndicesShape
scatter_nd_ops_test.ScatterNdTest.testUndefinedOutputShape
scatter_nd_ops_test.ScatterNdTest.testUndefinedUpdatesShape
scatter_nd_ops_test.ScatterNdTest.test_session
scatter_nd_ops_test.StatefulScatterNdTest.testConcurrentUpdates
scatter_nd_ops_test.StatefulScatterNdTest.testExtraIndicesDimensions
scatter_nd_ops_test.StatefulScatterNdTest.testRank3InvalidShape1
scatter_nd_ops_test.StatefulScatterNdTest.testRank3InvalidShape2
scatter_nd_ops_test.StatefulScatterNdTest.testRank3ValidShape
scatter_nd_ops_test.StatefulScatterNdTest.testResVarInvalidOutputShape
scatter_nd_ops_test.StatefulScatterNdTest.testScatterOutOfRangeCpu
#scatter_nd_ops_test.StatefulScatterNdTest.testScatterRepeatIndices
scatter_nd_ops_test.StatefulScatterNdTest.testSimple
scatter_nd_ops_test.StatefulScatterNdTest.testSimple2
scatter_nd_ops_test.StatefulScatterNdTest.testSimple3
scatter_nd_ops_test.StatefulScatterNdTest.testSimpleResource
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankAdd
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankSub
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankUpdate
scatter_nd_ops_test.StatefulScatterNdTest.test_session

slice_op_test.SliceTest.testComplex
slice_op_test.SliceTest.testEmpty
slice_op_test.SliceTest.testGradientsAll
Expand Down
67 changes: 67 additions & 0 deletions test/python/tensorflow/python_tests_list_gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,73 @@ relu_op_test.ReluTest.testGradientFloat32
relu_op_test.ReluTest.testGradientScalar
relu_op_test.ReluTest.testNumbers

scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testBool
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape1
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testEmptyOutputShape3
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testExtraIndicesDimensions
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank2ElementUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank2SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank3SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testGradientsRank7SliceUpdate
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testInvalidShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3InvalidShape1
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3InvalidShape2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testRank3ValidShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testScatterNdRepatedIndicesAdd
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch1DSliceDim2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch1DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch2DSliceDim2
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testSmokeScatterNdBatch2DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testString
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedIndicesShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedOutputShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.testUndefinedUpdatesShape
scatter_nd_ops_test.ScatterNdNonAliasingAddTest.test_session
scatter_nd_ops_test.ScatterNdTensorTest.testTensorScatterUpdateWithForwarding
scatter_nd_ops_test.ScatterNdTensorTest.testUpdateAddSub
scatter_nd_ops_test.ScatterNdTensorTest.testUpdateAddSubGradients
scatter_nd_ops_test.ScatterNdTensorTest.test_session
scatter_nd_ops_test.ScatterNdTest.testBool
scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape1
scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape2
scatter_nd_ops_test.ScatterNdTest.testEmptyOutputShape3
scatter_nd_ops_test.ScatterNdTest.testExtraIndicesDimensions
scatter_nd_ops_test.ScatterNdTest.testGradientsRank2ElementUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank2SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank3SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testGradientsRank7SliceUpdate
scatter_nd_ops_test.ScatterNdTest.testInvalidShape
scatter_nd_ops_test.ScatterNdTest.testRank3InvalidShape1
scatter_nd_ops_test.ScatterNdTest.testRank3InvalidShape2
scatter_nd_ops_test.ScatterNdTest.testRank3ValidShape
scatter_nd_ops_test.ScatterNdTest.testScatterNdRepatedIndicesAdd
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch1DSliceDim2
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch1DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch2DSliceDim2
scatter_nd_ops_test.ScatterNdTest.testSmokeScatterNdBatch2DSliceDim3ShapeRank7
scatter_nd_ops_test.ScatterNdTest.testString
scatter_nd_ops_test.ScatterNdTest.testUndefinedIndicesShape
scatter_nd_ops_test.ScatterNdTest.testUndefinedOutputShape
scatter_nd_ops_test.ScatterNdTest.testUndefinedUpdatesShape
scatter_nd_ops_test.ScatterNdTest.test_session
scatter_nd_ops_test.StatefulScatterNdTest.testConcurrentUpdates
scatter_nd_ops_test.StatefulScatterNdTest.testExtraIndicesDimensions
scatter_nd_ops_test.StatefulScatterNdTest.testRank3InvalidShape1
scatter_nd_ops_test.StatefulScatterNdTest.testRank3InvalidShape2
scatter_nd_ops_test.StatefulScatterNdTest.testRank3ValidShape
scatter_nd_ops_test.StatefulScatterNdTest.testResVarInvalidOutputShape
scatter_nd_ops_test.StatefulScatterNdTest.testScatterOutOfRangeCpu
#scatter_nd_ops_test.StatefulScatterNdTest.testScatterRepeatIndices
scatter_nd_ops_test.StatefulScatterNdTest.testSimple
scatter_nd_ops_test.StatefulScatterNdTest.testSimple2
scatter_nd_ops_test.StatefulScatterNdTest.testSimple3
scatter_nd_ops_test.StatefulScatterNdTest.testSimpleResource
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankAdd
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankSub
#scatter_nd_ops_test.StatefulScatterNdTest.testVariableRankUpdate
scatter_nd_ops_test.StatefulScatterNdTest.test_session

slice_op_test.SliceTest.testComplex
slice_op_test.SliceTest.testEmpty
slice_op_test.SliceTest.testGradientsAll
Expand Down
Loading

0 comments on commit bd7a76a

Please sign in to comment.