diff --git a/flex/codegen/src/building_context.h b/flex/codegen/src/building_context.h index 414b9add95e4..7dcd907a376c 100644 --- a/flex/codegen/src/building_context.h +++ b/flex/codegen/src/building_context.h @@ -26,6 +26,8 @@ namespace gs { static constexpr const char* time_stamp = "time_stamp"; static constexpr const char* graph_var = "graph"; static constexpr const char* GRAPE_INTERFACE_CLASS = "gs::MutableCSRInterface"; +static constexpr const char* SESSION_VAR = "sess"; +static constexpr const char* SESSION_CLASS_NAME = "GraphDBSession"; static constexpr const char* GRAPE_INTERFACE_HEADER = "flex/engines/hqps_db/database/mutable_csr_interface.h"; static constexpr const char* EDGE_EXPAND_OPT_NAME = "edge_expand_opt"; @@ -303,6 +305,10 @@ class BuildingContext { std::string GraphVar() const { return graph_var; } + std::string SessionVar() const { return SESSION_VAR; } + + std::string GetSessionTypeName() const { return SESSION_CLASS_NAME; } + void AddParameterVar(const codegen::ParamConst& var) { parameter_vars_.emplace_back(var); } diff --git a/flex/codegen/src/graph_types.h b/flex/codegen/src/graph_types.h index 48a7630537a2..fa0316b399a4 100644 --- a/flex/codegen/src/graph_types.h +++ b/flex/codegen/src/graph_types.h @@ -176,7 +176,11 @@ static codegen::ParamConst param_const_pb_to_param_const( } } -static std::string data_type_2_string(const codegen::DataType& data_type) { +// The second params only control the ret value when type is string.In some +// cases, we need to use std::string_view, but in some cases, we need to use +// std::string. +static std::string data_type_2_string(const codegen::DataType& data_type, + bool string_view = true) { switch (data_type) { case codegen::DataType::kInt32: return "int32_t"; @@ -185,7 +189,11 @@ static std::string data_type_2_string(const codegen::DataType& data_type) { case codegen::DataType::kDouble: return "double"; case codegen::DataType::kString: - return "std::string_view"; + if (string_view) { + return "std::string_view"; + } else { + return "std::string"; + } case codegen::DataType::kInt64Array: return "std::vector"; case codegen::DataType::kInt32Array: diff --git a/flex/codegen/src/hqps_generator.h b/flex/codegen/src/hqps_generator.h index 3bbe5524e2c5..37fb4403c10b 100644 --- a/flex/codegen/src/hqps_generator.h +++ b/flex/codegen/src/hqps_generator.h @@ -62,27 +62,10 @@ static constexpr const char* QUERY_TEMPLATE_STR = " // constructor\n" " %3%() {}\n" "// Query function for query class\n" - " %5% Query(%7%) const{\n" + " %5% Query(%7%) override {\n" + " %4% graph(%12%);\n" " %8%\n" " }\n" - "// Wrapper query function for query class\n" - " bool DoQuery(gs::GraphDBSession& sess, Decoder& decoder, Encoder& " - "encoder) " - "override {\n" - " //decoding params from decoder, and call real query func\n" - " %9%\n" - " %4% %6%(sess);" - " auto res = Query(%10%);\n" - " // dump results to string\n" - " std::string res_str = res.SerializeAsString();\n" - " // encode results to encoder\n" - " if (!res_str.empty()){\n" - " encoder.put_string_view(res_str);\n" - " }\n" - " return true;\n" - " }\n" - " //private members\n" - " private:\n" "};\n" "} // namespace gs\n" "\n" @@ -232,8 +215,9 @@ class QueryGenerator { ss << std::endl; expr_code = ss.str(); } - std::string dynamic_vars_str = - ctx_.GetGraphInterface() + "& " + ctx_.GraphVar(); + std::string dynamic_vars_str = std::string("const ") + + ctx_.GetSessionTypeName() + "& " + + ctx_.SessionVar(); if (ctx_.GetParameterVars().size() > 0) { dynamic_vars_str += ", "; dynamic_vars_str += concat_param_vars(ctx_.GetParameterVars()); @@ -241,6 +225,7 @@ class QueryGenerator { std::string decoding_params_code, decoded_params_str; std::tie(decoding_params_code, decoded_params_str) = decode_params_from_decoder(ctx_.GetParameterVars()); + auto param_types = get_param_types(ctx_.GetParameterVars()); std::string call_query_input_code = ctx_.GraphVar(); if (decoded_params_str.size() > 0) { call_query_input_code += ", " + decoded_params_str; @@ -249,7 +234,8 @@ class QueryGenerator { formater % ctx_.GetGraphHeader() % expr_code % ctx_.GetQueryClassName() % ctx_.GetGraphInterface() % ctx_.GetQueryRet() % ctx_.GraphVar() % dynamic_vars_str % query_code % decoding_params_code % - call_query_input_code % get_app_base_name(); + call_query_input_code % get_app_base_name(param_types) % + ctx_.SessionVar(); return formater.str(); } @@ -267,7 +253,39 @@ class QueryGenerator { // This info should be parse from physical plan. // Currently always return writeAppBase, since physical plan hasn't // provided this info. - std::string get_app_base_name() { return "CypherInternalPbWriteAppBase"; } + std::string get_app_base_name(const std::vector& param_types) { + std::stringstream ss; + ss << "CypherReadAppBase<"; + for (size_t i = 0; i < param_types.size(); ++i) { + ss << param_types[i]; + if (i != param_types.size() - 1) { + ss << ", "; + } + } + ss << ">"; + return ss.str(); + } + + std::vector get_param_types( + std::vector param_vars) { + std::vector param_types; + if (param_vars.size() > 0) { + sort(param_vars.begin(), param_vars.end(), + [](const auto& a, const auto& b) { return a.id < b.id; }); + CHECK(param_vars[0].id == 0); + for (size_t i = 0; i < param_vars.size(); ++i) { + if (i > 0 && param_vars[i].id == param_vars[i - 1].id) { + CHECK(param_vars[i].var_name == param_vars[i - 1].var_name) + << " " << param_vars[i].var_name << " " + << param_vars[i - 1].var_name; + continue; + } else { + param_types.push_back(data_type_2_string(param_vars[i].type, false)); + } + } + } + return param_types; + } // copy the param vars to sort std::string concat_param_vars( @@ -285,7 +303,7 @@ class QueryGenerator { << param_vars[i - 1].var_name; continue; } else { - ss << data_type_2_string(param_vars[i].type) << " " + ss << data_type_2_string(param_vars[i].type, false) << " " << param_vars[i].var_name << ","; } } diff --git a/flex/engines/graph_db/app/builtin/count_vertices.cc b/flex/engines/graph_db/app/builtin/count_vertices.cc index e36e6a1b1398..134682becc33 100644 --- a/flex/engines/graph_db/app/builtin/count_vertices.cc +++ b/flex/engines/graph_db/app/builtin/count_vertices.cc @@ -16,19 +16,15 @@ namespace gs { -bool CountVertices::DoQuery(GraphDBSession& sess, Decoder& input, - Encoder& output) { +results::CollectiveResults CountVertices::Query(const GraphDBSession& sess, + std::string label_name) { // First get the read transaction. auto txn = sess.GetReadTransaction(); // We expect one param of type string from decoder. - if (input.empty()) { - return false; - } - std::string label_name{input.get_string()}; const auto& schema = txn.schema(); if (!schema.has_vertex_label(label_name)) { - output.put_string_view("The requested label doesn't exits."); - return false; // The requested label doesn't exits. + LOG(ERROR) << "Label " << label_name << " not found in schema."; + return results::CollectiveResults(); } auto label_id = schema.get_vertex_label_id(label_name); // The vertices are labeled internally from 0 ~ vertex_label_num, accumulate @@ -42,10 +38,7 @@ bool CountVertices::DoQuery(GraphDBSession& sess, Decoder& input, ->mutable_element() ->mutable_object() ->set_i32(vertex_num); - - output.put_string_view(results.SerializeAsString()); - txn.Commit(); - return true; + return results; } AppWrapper CountVerticesFactory::CreateApp(const GraphDB& db) { diff --git a/flex/engines/graph_db/app/builtin/count_vertices.h b/flex/engines/graph_db/app/builtin/count_vertices.h index 76d7e1bb403d..b9e063365b6e 100644 --- a/flex/engines/graph_db/app/builtin/count_vertices.h +++ b/flex/engines/graph_db/app/builtin/count_vertices.h @@ -20,10 +20,11 @@ namespace gs { // A simple app to count the number of vertices of a given label. -class CountVertices : public CypherInternalPbWriteAppBase { +class CountVertices : public CypherReadAppBase { public: CountVertices() {} - bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override; + results::CollectiveResults Query(const GraphDBSession& sess, + std::string param) override; }; class CountVerticesFactory : public AppFactoryBase { diff --git a/flex/engines/graph_db/app/builtin/k_hop_neighbors.cc b/flex/engines/graph_db/app/builtin/k_hop_neighbors.cc index 6d34ed1d72f1..c022ac94884a 100644 --- a/flex/engines/graph_db/app/builtin/k_hop_neighbors.cc +++ b/flex/engines/graph_db/app/builtin/k_hop_neighbors.cc @@ -16,24 +16,20 @@ namespace gs { -bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input, - Encoder& output) { +results::CollectiveResults KNeighbors::Query(const GraphDBSession& sess, + std::string label_name, + int64_t vertex_id, int32_t k) { auto txn = sess.GetReadTransaction(); - Schema schema_ = txn.schema(); - if (input.empty()) { - return false; - } - int64_t vertex_id_ = input.get_long(); - std::string label_name{input.get_string()}; - int k = input.get_int(); + const Schema& schema_ = txn.schema(); if (k <= 0) { - output.put_string_view("k must be greater than 0."); - return false; + LOG(ERROR) << "k must be greater than 0."; + return {}; } if (!schema_.has_vertex_label(label_name)) { - output.put_string_view("The requested label doesn't exits."); - return false; // The requested label doesn't exits. + // output.put_string_view("The requested label doesn't exits."); + LOG(ERROR) << "The requested label doesn't exits."; + return {}; } label_t vertex_label_ = schema_.get_vertex_label_id(label_name); struct pair_hash { @@ -55,9 +51,9 @@ bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input, nei_label_.push_back(vertex_label_); vid_t vertex_index{}; - if (!txn.GetVertexIndex(vertex_label_, (int64_t) vertex_id_, vertex_index)) { - output.put_string_view("get index fail."); - return false; + if (!txn.GetVertexIndex(vertex_label_, vertex_id, vertex_index)) { + LOG(ERROR) << "Vertex not found."; + return {}; } nei_index_.push_back(vertex_index); // get k hop neighbors @@ -123,10 +119,9 @@ bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input, ->mutable_object() ->set_i64(txn.GetVertexId(vertex_.first, vertex_.second).AsInt64()); } - output.put_string_view(results.SerializeAsString()); txn.Commit(); - return true; + return results; } AppWrapper KNeighborsFactory::CreateApp(const GraphDB& db) { diff --git a/flex/engines/graph_db/app/builtin/k_hop_neighbors.h b/flex/engines/graph_db/app/builtin/k_hop_neighbors.h index 81d58dc264ae..8c58c2efc724 100644 --- a/flex/engines/graph_db/app/builtin/k_hop_neighbors.h +++ b/flex/engines/graph_db/app/builtin/k_hop_neighbors.h @@ -19,10 +19,12 @@ #include "flex/engines/hqps_db/app/interactive_app_base.h" namespace gs { -class KNeighbors : public CypherInternalPbWriteAppBase { +class KNeighbors : public CypherReadAppBase { public: KNeighbors() {} - bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override; + results::CollectiveResults Query(const GraphDBSession& sess, + std::string label_name, int64_t vertex_id, + int32_t hop_range) override; }; class KNeighborsFactory : public AppFactoryBase { diff --git a/flex/engines/graph_db/app/builtin/pagerank.cc b/flex/engines/graph_db/app/builtin/pagerank.cc index 5de165984896..a51178ba8329 100644 --- a/flex/engines/graph_db/app/builtin/pagerank.cc +++ b/flex/engines/graph_db/app/builtin/pagerank.cc @@ -16,53 +16,43 @@ namespace gs { -bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { +results::CollectiveResults PageRank::Query(const GraphDBSession& sess, + std::string vertex_label, + std::string edge_label, + double damping_factor, + int max_iterations, double epsilon) { auto txn = sess.GetReadTransaction(); - if (input.empty()) { - output.put_string_view( - "Arguments required(vertex_label, edge_label, damping_factor_, max_iterations_, epsilon_)\n \ - for example:(\"person\", \"knows\", 0.85, 100, 0.000001)"); - return false; - } - - std::string vertex_label{input.get_string()}; - std::string edge_label{input.get_string()}; - - damping_factor_ = input.get_double(); - max_iterations_ = input.get_int(); - epsilon_ = input.get_double(); if (!sess.schema().has_vertex_label(vertex_label)) { - output.put_string_view("The requested vertex label doesn't exits."); - return false; + LOG(ERROR) << "The requested vertex label doesn't exits."; + return {}; } if (!sess.schema().has_edge_label(vertex_label, vertex_label, edge_label)) { - output.put_string_view("The requested edge label doesn't exits."); - return false; + LOG(ERROR) << "The requested edge label doesn't exits."; + return {}; } - if (damping_factor_ < 0 || damping_factor_ >= 1) { - output.put_string_view( - "The value of the damping_factor_ is between 0 and 1."); - return false; + if (damping_factor < 0 || damping_factor >= 1) { + LOG(ERROR) << "The value of the damping factor is between 0 and 1."; + return {}; } - if (max_iterations_ <= 0) { - output.put_string_view("max_iterations_ must be greater than 0."); - return false; + if (max_iterations <= 0) { + LOG(ERROR) << "The value of the max iterations must be greater than 0."; + return {}; } - if (epsilon_ < 0 || epsilon_ >= 1) { - output.put_string_view("The value of the epsilon_ is between 0 and 1."); - return false; + if (epsilon < 0 || epsilon >= 1) { + LOG(ERROR) << "The value of the epsilon is between 0 and 1."; + return {}; } - vertex_label_id_ = sess.schema().get_vertex_label_id(vertex_label); - edge_label_id_ = sess.schema().get_edge_label_id(edge_label); + auto vertex_label_id = sess.schema().get_vertex_label_id(vertex_label); + auto edge_label_id = sess.schema().get_edge_label_id(edge_label); - auto num_vertices = txn.GetVertexNum(vertex_label_id_); + auto num_vertices = txn.GetVertexNum(vertex_label_id); std::unordered_map pagerank; std::unordered_map new_pagerank; - auto vertex_iter = txn.GetVertexIterator(vertex_label_id_); + auto vertex_iter = txn.GetVertexIterator(vertex_label_id); while (vertex_iter.IsValid()) { vid_t vid = vertex_iter.GetIndex(); @@ -73,23 +63,23 @@ bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { std::unordered_map outdegree; - for (int iter = 0; iter < max_iterations_; ++iter) { + for (int iter = 0; iter < max_iterations; ++iter) { for (auto& kv : new_pagerank) { kv.second = 0.0; } - auto vertex_iter = txn.GetVertexIterator(vertex_label_id_); + auto vertex_iter = txn.GetVertexIterator(vertex_label_id); while (vertex_iter.IsValid()) { vid_t v = vertex_iter.GetIndex(); double sum = 0.0; - auto edges = txn.GetInEdgeIterator(vertex_label_id_, v, vertex_label_id_, - edge_label_id_); + auto edges = txn.GetInEdgeIterator(vertex_label_id, v, vertex_label_id, + edge_label_id); while (edges.IsValid()) { auto neighbor = edges.GetNeighbor(); if (outdegree[neighbor] == 0) { auto out_edges = txn.GetOutEdgeIterator( - vertex_label_id_, neighbor, vertex_label_id_, edge_label_id_); + vertex_label_id, neighbor, vertex_label_id, edge_label_id); while (out_edges.IsValid()) { outdegree[neighbor]++; out_edges.Next(); @@ -100,7 +90,7 @@ bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { } new_pagerank[v] = - damping_factor_ * sum + (1.0 - damping_factor_) / num_vertices; + damping_factor * sum + (1.0 - damping_factor) / num_vertices; vertex_iter.Next(); } @@ -109,7 +99,7 @@ bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { diff += std::abs(new_pagerank[kv.first] - kv.second); } - if (diff < epsilon_) { + if (diff < epsilon) { break; } @@ -119,7 +109,7 @@ bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { results::CollectiveResults results; for (auto kv : pagerank) { - int64_t oid_ = txn.GetVertexId(vertex_label_id_, kv.first).AsInt64(); + int64_t oid_ = txn.GetVertexId(vertex_label_id, kv.first).AsInt64(); auto result = results.add_results(); result->mutable_record() ->add_columns() @@ -141,10 +131,8 @@ bool PageRank::DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) { ->set_f64(kv.second); } - output.put_string_view(results.SerializeAsString()); - txn.Commit(); - return true; + return results; } AppWrapper PageRankFactory::CreateApp(const GraphDB& db) { diff --git a/flex/engines/graph_db/app/builtin/pagerank.h b/flex/engines/graph_db/app/builtin/pagerank.h index abc6afc471db..fa73c6fa9fda 100644 --- a/flex/engines/graph_db/app/builtin/pagerank.h +++ b/flex/engines/graph_db/app/builtin/pagerank.h @@ -19,23 +19,15 @@ #include "flex/engines/hqps_db/app/interactive_app_base.h" namespace gs { -class PageRank : public CypherInternalPbWriteAppBase { +class PageRank + : public CypherReadAppBase { public: - PageRank() - : damping_factor_(0.85), - max_iterations_(100), - epsilon_(1e-6), - vertex_label_id_(0), - edge_label_id_(0) {} - bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override; - - private: - double damping_factor_; - int max_iterations_; - double epsilon_; - - label_t vertex_label_id_; - label_t edge_label_id_; + PageRank() {} + results::CollectiveResults Query(const GraphDBSession& sess, + std::string vertex_label, + std::string edge_label, + double damping_factor, int max_iterations, + double epsilon); }; class PageRankFactory : public AppFactoryBase { diff --git a/flex/engines/graph_db/app/builtin/shortest_path_among_three.cc b/flex/engines/graph_db/app/builtin/shortest_path_among_three.cc index 15b6c4381df5..67e17b3495f0 100644 --- a/flex/engines/graph_db/app/builtin/shortest_path_among_three.cc +++ b/flex/engines/graph_db/app/builtin/shortest_path_among_three.cc @@ -16,25 +16,18 @@ namespace gs { -bool ShortestPathAmongThree::DoQuery(GraphDBSession& sess, Decoder& input, - Encoder& output) { +results::CollectiveResults ShortestPathAmongThree::Query( + const GraphDBSession& sess, std::string label_name1, int64_t oid1, + std::string label_name2, int64_t oid2, std::string label_name3, + int64_t oid3) { ReadTransaction txn = sess.GetReadTransaction(); - if (input.empty()) { - return false; - } - Schema schema_ = txn.schema(); - std::string label_name1{input.get_string()}; - int64_t vid1 = input.get_long(); - std::string label_name2{input.get_string()}; - int64_t vid2 = input.get_long(); - std::string label_name3{input.get_string()}; - int64_t vid3 = input.get_long(); + const Schema& schema_ = txn.schema(); if (!schema_.has_vertex_label(label_name1) || !schema_.has_vertex_label(label_name2) || !schema_.has_vertex_label(label_name3)) { - output.put_string_view("The requested label doesn't exits."); - return false; + LOG(ERROR) << "The requested label doesn't exits."; + return {}; } label_t label_v1 = schema_.get_vertex_label_id(label_name1); label_t label_v2 = schema_.get_vertex_label_id(label_name2); @@ -42,11 +35,11 @@ bool ShortestPathAmongThree::DoQuery(GraphDBSession& sess, Decoder& input, vid_t index_v1{}; vid_t index_v2{}; vid_t index_v3{}; - if (!txn.GetVertexIndex(label_v1, (int64_t) vid1, index_v1) || - !txn.GetVertexIndex(label_v2, (int64_t) vid2, index_v2) || - !txn.GetVertexIndex(label_v3, (int64_t) vid3, index_v3)) { - output.put_string_view("get index fail."); - return false; + if (!txn.GetVertexIndex(label_v1, oid1, index_v1) || + !txn.GetVertexIndex(label_v2, oid2, index_v2) || + !txn.GetVertexIndex(label_v3, oid3, index_v3)) { + LOG(ERROR) << "Vertex not found."; + return {}; } // get the three shortest paths std::vector> v1v2result_; @@ -99,9 +92,8 @@ bool ShortestPathAmongThree::DoQuery(GraphDBSession& sess, Decoder& input, ->mutable_object() ->set_str(result_path); - output.put_string_view(results.SerializeAsString()); txn.Commit(); - return true; + return results; } bool ShortestPathAmongThree::ShortestPath( diff --git a/flex/engines/graph_db/app/builtin/shortest_path_among_three.h b/flex/engines/graph_db/app/builtin/shortest_path_among_three.h index 71235c2e1b35..beebc3642e81 100644 --- a/flex/engines/graph_db/app/builtin/shortest_path_among_three.h +++ b/flex/engines/graph_db/app/builtin/shortest_path_among_three.h @@ -19,10 +19,15 @@ #include "flex/engines/hqps_db/app/interactive_app_base.h" namespace gs { -class ShortestPathAmongThree : public CypherInternalPbWriteAppBase { +class ShortestPathAmongThree + : public CypherReadAppBase { public: ShortestPathAmongThree() {} - bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override; + results::CollectiveResults Query(const GraphDBSession& sess, + std::string label_name1, int64_t oid1, + std::string label_name2, int64_t oid2, + std::string label_name3, int64_t oid3); private: bool ShortestPath(const gs::ReadTransaction& txn, label_t v1_l, diff --git a/flex/engines/graph_db/app/hqps_app.cc b/flex/engines/graph_db/app/hqps_app.cc index 6412ff1cfcf6..5457916fecc1 100644 --- a/flex/engines/graph_db/app/hqps_app.cc +++ b/flex/engines/graph_db/app/hqps_app.cc @@ -55,8 +55,11 @@ bool HQPSAdhocReadApp::Query(const GraphDBSession& graph, Decoder& input, << app_wrapper.app()->mode(); return false; } + // Adhoc read app should not have input, so we pass an empty decoder + std::vector dummy_input; + gs::Decoder dummy_decoder(dummy_input.data(), dummy_input.size()); auto casted = dynamic_cast(app_wrapper.app()); - return casted->Query(graph, input, output); + return casted->Query(graph, dummy_decoder, output); } bool HQPSAdhocWriteApp::Query(GraphDBSession& graph, Decoder& input, diff --git a/flex/engines/graph_db/database/graph_db_session.cc b/flex/engines/graph_db/database/graph_db_session.cc index 05f31700f85d..96e51f2d70f2 100644 --- a/flex/engines/graph_db/database/graph_db_session.cc +++ b/flex/engines/graph_db/database/graph_db_session.cc @@ -247,7 +247,8 @@ GraphDBSession::parse_query_type_from_cypher_json( const std::string_view& str_view) { VLOG(10) << "string view: " << str_view; rapidjson::Document j; - if (j.Parse(std::string(str_view.data(), str_view.size())).HasParseError()) { + if (j.Parse(std::string(str_view.data(), str_view.size() - 1)) + .HasParseError()) { LOG(ERROR) << "Fail to parse json from input content"; return Result>(gs::Status( StatusCode::INTERNAL_ERROR, "Fail to parse json from input content")); @@ -272,7 +273,7 @@ Result> GraphDBSession::parse_query_type_from_cypher_internal( const std::string_view& str_view) { procedure::Query cur_query; - if (!cur_query.ParseFromArray(str_view.data(), str_view.size())) { + if (!cur_query.ParseFromArray(str_view.data(), str_view.size() - 1)) { LOG(ERROR) << "Fail to parse query from input content"; return Result>(gs::Status( StatusCode::INTERNAL_ERROR, "Fail to parse query from input content")); diff --git a/flex/engines/graph_db/database/graph_db_session.h b/flex/engines/graph_db/database/graph_db_session.h index 158511299e10..918e22d62514 100644 --- a/flex/engines/graph_db/database/graph_db_session.h +++ b/flex/engines/graph_db/database/graph_db_session.h @@ -158,15 +158,17 @@ class GraphDBSession { } else if (input_tag == static_cast(InputFormat::kCypherJson)) { // For cypherJson there is no query-id provided. The query name is // provided in the json string. - std::string_view str_view(input.data(), len - 1); + // We don't discard the last byte, since we need it to determine the input + // format when deserializing the input arguments in deserialize() function + std::string_view str_view(input.data(), len); return parse_query_type_from_cypher_json(str_view); } else if (input_tag == static_cast(InputFormat::kCypherProtoProcedure)) { // For cypher internal procedure, the query_name is // provided in the protobuf message. - std::string_view str_view(input.data(), len - 1); + // Same as cypherJson, we don't discard the last byte. + std::string_view str_view(input.data(), len); return parse_query_type_from_cypher_internal(str_view); - } else { return Result>( gs::Status(StatusCode::INVALID_ARGUMENT, diff --git a/flex/engines/hqps_db/app/interactive_app_base.h b/flex/engines/hqps_db/app/interactive_app_base.h index 78b29a719c70..5186379ad948 100644 --- a/flex/engines/hqps_db/app/interactive_app_base.h +++ b/flex/engines/hqps_db/app/interactive_app_base.h @@ -18,6 +18,7 @@ #include #include "flex/engines/graph_db/app/app_base.h" +#include "flex/engines/graph_db/database/graph_db_session.h" #include "flex/proto_generated_gie/results.pb.h" #include "flex/proto_generated_gie/stored_procedure.pb.h" #include "flex/utils/property/types.h" @@ -25,45 +26,78 @@ namespace gs { -inline void put_argument(gs::Encoder& encoder, - const procedure::Argument& argument) { - auto& value = argument.value(); - auto item_case = value.item_case(); - switch (item_case) { - case common::Value::kI32: - encoder.put_int(value.i32()); - break; - case common::Value::kI64: - encoder.put_long(value.i64()); - break; - case common::Value::kF64: - encoder.put_double(value.f64()); - break; - case common::Value::kStr: - encoder.put_string(value.str()); - break; - default: - LOG(ERROR) << "Not recognizable param type" << static_cast(item_case); +template +inline bool parse_input_argument_from_proto_impl( + TUPLE_T& tuple, + const google::protobuf::RepeatedPtrField& args) { + if constexpr (I == sizeof...(ARGS)) { + return true; + } else { + auto& type = std::get(tuple); + auto& argument = args.Get(I); + auto& value = argument.value(); + auto item_case = value.item_case(); + if (item_case == common::Value::kI32) { + if constexpr (std::is_same>::value) { + type = value.i32(); + } else { + LOG(ERROR) << "Type mismatch: " << item_case << "at " << I; + return false; + } + } else if (item_case == common::Value::kI64) { + if constexpr (std::is_same>::value) { + type = value.i64(); + } else { + LOG(ERROR) << "Type mismatch: " << item_case << "at " << I; + return false; + } + } else if (item_case == common::Value::kF64) { + if constexpr (std::is_same>::value) { + type = value.f64(); + } else { + LOG(ERROR) << "Type mismatch: " << item_case << "at " << I; + return false; + } + } else if (item_case == common::Value::kStr) { + if constexpr (std::is_same>::value) { + type = value.str(); + } else { + LOG(ERROR) << "Type mismatch: " << item_case << "at " << I; + return false; + } + } else { + LOG(ERROR) << "Not recognizable param type" << item_case; + return false; + } + return parse_input_argument_from_proto_impl(tuple, + args); } } -inline bool parse_input_argument(gs::Decoder& raw_input, - gs::Encoder& argument_encoder) { - if (raw_input.size() == 0) { +template +inline bool parse_input_argument_from_proto(std::tuple& tuple, + std::string_view sv) { + if (sv.size() == 0) { VLOG(10) << "No arguments found in input"; return true; } procedure::Query cur_query; - if (!cur_query.ParseFromArray(raw_input.data(), raw_input.size())) { + if (!cur_query.ParseFromArray(sv.data(), sv.size())) { LOG(ERROR) << "Fail to parse query from input content"; return false; } auto& args = cur_query.arguments(); - for (int32_t i = 0; i < args.size(); ++i) { - put_argument(argument_encoder, args[i]); + if (args.size() != sizeof...(ARGS)) { + LOG(ERROR) << "Arguments size mismatch: " << args.size() << " vs " + << sizeof...(ARGS); + return false; } - VLOG(10) << ", num args: " << args.size(); - return true; + return parse_input_argument_from_proto_impl<0, std::tuple, ARGS...>( + tuple, args); } class GraphDBSession; @@ -110,9 +144,14 @@ bool deserialize_impl(TUPLE_T& tuple, const rapidjson::Value& json) { } template -bool deserialize(std::tuple& tuple, std::string_view sv) { +bool parse_input_argument_from_json(std::tuple& tuple, + std::string_view sv) { rapidjson::Document j; VLOG(10) << "parsing string: " << sv << ",size" << sv.size(); + if (sv.empty()) { + LOG(INFO) << "No arguments found in input"; + return sizeof...(ARGS) == 0; + } if (j.Parse(std::string(sv)).HasParseError()) { LOG(ERROR) << "Fail to parse json from input content"; return false; @@ -140,6 +179,27 @@ bool deserialize(std::tuple& tuple, std::string_view sv) { } } +template +bool deserialize(std::tuple& tuple, std::string_view sv) { + // Deserialize input argument from the payload. The last byte is the input + // format, could only be kCypherJson or kCypherProtoProcedure. + if (sv.empty()) { + return sizeof...(ARGS) == 0; + } + auto input_format = static_cast(sv.back()); + std::string_view payload(sv.data(), sv.size() - 1); + if (input_format == + static_cast(GraphDBSession::InputFormat::kCypherJson)) { + return parse_input_argument_from_json(tuple, payload); + } else if (input_format == + static_cast( + GraphDBSession::InputFormat::kCypherProtoProcedure)) { + return parse_input_argument_from_proto(tuple, payload); + } else { + LOG(ERROR) << "Invalid input format: " << input_format; + return false; + } +} // for cypher procedure template class CypherReadAppBase : public ReadAppBase { @@ -210,26 +270,6 @@ class CypherWriteAppBase : public WriteAppBase { } }; -// For internal cypher-gen procedure, with pb input and output -// Codegen app should inherit from this class -class CypherInternalPbWriteAppBase : public WriteAppBase { - public: - AppType type() const override { return AppType::kCypherProcedure; } - - virtual bool DoQuery(GraphDBSession& db, Decoder& input, Encoder& output) = 0; - - bool Query(GraphDBSession& db, Decoder& raw_input, Encoder& output) override { - std::vector output_buffer; - gs::Encoder argument_encoder(output_buffer); - if (!parse_input_argument(raw_input, argument_encoder)) { - LOG(ERROR) << "Failed to parse input argument!"; - return false; - } - gs::Decoder argument_decoder(output_buffer.data(), output_buffer.size()); - return DoQuery(db, argument_decoder, output); - } -}; - } // namespace gs #endif // ENGINES_HQPS_DB_APP_INTERACTIVE_APP_BASE_H_ diff --git a/flex/engines/http_server/handler/graph_db_http_handler.cc b/flex/engines/http_server/handler/graph_db_http_handler.cc index 2b5391d35683..c5ca6c8b9b8f 100644 --- a/flex/engines/http_server/handler/graph_db_http_handler.cc +++ b/flex/engines/http_server/handler/graph_db_http_handler.cc @@ -747,7 +747,7 @@ class adhoc_query_handler : public StoppableHandler { // TODO(zhanglei): choose read or write based on the request, after the // read/write info is supported in physical plan // The content contains the path to dynamic library - param.content.append(gs::Schema::HQPS_ADHOC_WRITE_PLUGIN_ID_STR, 1); + param.content.append(gs::Schema::HQPS_ADHOC_READ_PLUGIN_ID_STR, 1); param.content.append(gs::GraphDBSession::kCypherProtoAdhocStr, 1); return get_executors()[StoppableHandler::shard_id()][dst_executor] .run_graph_db_query(query_param{std::move(param.content)}) diff --git a/flex/interactive/sdk/python/gs_interactive/tests/test_driver.py b/flex/interactive/sdk/python/gs_interactive/tests/test_driver.py index ebf2b33b0c51..35b58f6fac74 100644 --- a/flex/interactive/sdk/python/gs_interactive/tests/test_driver.py +++ b/flex/interactive/sdk/python/gs_interactive/tests/test_driver.py @@ -372,7 +372,7 @@ def createCypherProcedure(self): create_proc_request = CreateProcedureRequest( name=self._cypher_proc_name, description="test procedure", - query="MATCH (n) RETURN COUNT(n);", + query="MATCH (n: person) where n.name =$personName RETURN COUNT(n);", type="cypher", ) resp = self._sess.create_procedure(self._graph_id, create_proc_request) @@ -508,7 +508,7 @@ def getStatistics(self): def callProcedure(self): with self._driver.getNeo4jSession() as session: - result = session.run("CALL test_procedure();") + result = session.run('CALL test_procedure("marko");') print("call procedure result: ", result) def callPrcedureWithServiceStop(self): diff --git a/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py b/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py index 0dfd26fcbad3..63ffff495ef3 100644 --- a/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py +++ b/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py @@ -220,8 +220,8 @@ def test_builtin_procedure(interactive_session, neo4j_session, create_modern_gra neo4j_session, create_modern_graph, "k_neighbors", - "1L", '"person"', + "1L", "2", ) diff --git a/flex/storages/metadata/graph_meta_store.cc b/flex/storages/metadata/graph_meta_store.cc index 3d668f35cc91..10b9f079782f 100644 --- a/flex/storages/metadata/graph_meta_store.cc +++ b/flex/storages/metadata/graph_meta_store.cc @@ -82,11 +82,11 @@ const std::vector& get_builtin_plugin_metas() { pagerank.update_time = GetCurrentTimeStamp(); pagerank.params.push_back({"vertex_label", PropertyType::kString}); pagerank.params.push_back({"edge_label", PropertyType::kString}); - pagerank.params.push_back({"damping_factor_", PropertyType::kDouble}); - pagerank.params.push_back({"max_iterations_", PropertyType::kInt32}); - pagerank.params.push_back({"epsilon_", PropertyType::kDouble}); - pagerank.returns.push_back({"label name", PropertyType::kString}); - pagerank.returns.push_back({"vertex oid", PropertyType::kInt64}); + pagerank.params.push_back({"damping_factor", PropertyType::kDouble}); + pagerank.params.push_back({"max_iterations", PropertyType::kInt32}); + pagerank.params.push_back({"epsilon", PropertyType::kDouble}); + pagerank.returns.push_back({"label_name", PropertyType::kString}); + pagerank.returns.push_back({"vertex_oid", PropertyType::kInt64}); pagerank.returns.push_back({"pagerank", PropertyType::kDouble}); builtin_plugins.push_back(pagerank); @@ -100,11 +100,11 @@ const std::vector& get_builtin_plugin_metas() { k_neighbors.type = "cypher"; k_neighbors.creation_time = GetCurrentTimeStamp(); k_neighbors.update_time = GetCurrentTimeStamp(); - k_neighbors.params.push_back({"vid", PropertyType::kInt64}); k_neighbors.params.push_back({"label_name", PropertyType::kString}); + k_neighbors.params.push_back({"oid", PropertyType::kInt64}); k_neighbors.params.push_back({"k", PropertyType::kInt32}); - k_neighbors.returns.push_back({"label name", PropertyType::kString}); - k_neighbors.returns.push_back({"vertex oid", PropertyType::kInt64}); + k_neighbors.returns.push_back({"label_name", PropertyType::kString}); + k_neighbors.returns.push_back({"vertex_oid", PropertyType::kInt64}); builtin_plugins.push_back(k_neighbors); // shortest_path_among_three @@ -120,17 +120,16 @@ const std::vector& get_builtin_plugin_metas() { shortest_path_among_three.update_time = GetCurrentTimeStamp(); shortest_path_among_three.params.push_back( {"label_name1", PropertyType::kString}); - shortest_path_among_three.params.push_back({"vid1", PropertyType::kInt64}); + shortest_path_among_three.params.push_back({"oid1", PropertyType::kInt64}); shortest_path_among_three.params.push_back( {"label_name2", PropertyType::kString}); - shortest_path_among_three.params.push_back({"vid2", PropertyType::kInt64}); + shortest_path_among_three.params.push_back({"oid2", PropertyType::kInt64}); shortest_path_among_three.params.push_back( {"label_name3", PropertyType::kString}); - shortest_path_among_three.params.push_back({"vid3", PropertyType::kInt64}); + shortest_path_among_three.params.push_back({"oid3", PropertyType::kInt64}); shortest_path_among_three.returns.push_back( {"shortest_path_among_three (label name, vertex oid)", PropertyType::kString}); - initialized = true; builtin_plugins.push_back(shortest_path_among_three); initialized = true;