Skip to content

Commit

Permalink
refactor(interactive): Refactor the codegen and argument parsing (#4246)
Browse files Browse the repository at this point in the history
- Refactor the codegen for stored procedure, let adhoc query and
procedure query all inherit from `AdhocReadApp`.
- Refine the `create_procedure` and `call_procedure` test in sdk test.
- Fix argument parsing problem.
  • Loading branch information
zhanglei1949 authored Oct 12, 2024
1 parent 44599e4 commit 5d89a61
Show file tree
Hide file tree
Showing 19 changed files with 257 additions and 212 deletions.
6 changes: 6 additions & 0 deletions flex/codegen/src/building_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace gs {
static constexpr const char* time_stamp = "time_stamp";
static constexpr const char* graph_var = "graph";
static constexpr const char* GRAPE_INTERFACE_CLASS = "gs::MutableCSRInterface";
static constexpr const char* SESSION_VAR = "sess";
static constexpr const char* SESSION_CLASS_NAME = "GraphDBSession";
static constexpr const char* GRAPE_INTERFACE_HEADER =
"flex/engines/hqps_db/database/mutable_csr_interface.h";
static constexpr const char* EDGE_EXPAND_OPT_NAME = "edge_expand_opt";
Expand Down Expand Up @@ -303,6 +305,10 @@ class BuildingContext {

std::string GraphVar() const { return graph_var; }

std::string SessionVar() const { return SESSION_VAR; }

std::string GetSessionTypeName() const { return SESSION_CLASS_NAME; }

void AddParameterVar(const codegen::ParamConst& var) {
parameter_vars_.emplace_back(var);
}
Expand Down
12 changes: 10 additions & 2 deletions flex/codegen/src/graph_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ static codegen::ParamConst param_const_pb_to_param_const(
}
}

static std::string data_type_2_string(const codegen::DataType& data_type) {
// The second params only control the ret value when type is string.In some
// cases, we need to use std::string_view, but in some cases, we need to use
// std::string.
static std::string data_type_2_string(const codegen::DataType& data_type,
bool string_view = true) {
switch (data_type) {
case codegen::DataType::kInt32:
return "int32_t";
Expand All @@ -185,7 +189,11 @@ static std::string data_type_2_string(const codegen::DataType& data_type) {
case codegen::DataType::kDouble:
return "double";
case codegen::DataType::kString:
return "std::string_view";
if (string_view) {
return "std::string_view";
} else {
return "std::string";
}
case codegen::DataType::kInt64Array:
return "std::vector<int64_t>";
case codegen::DataType::kInt32Array:
Expand Down
66 changes: 42 additions & 24 deletions flex/codegen/src/hqps_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,10 @@ static constexpr const char* QUERY_TEMPLATE_STR =
" // constructor\n"
" %3%() {}\n"
"// Query function for query class\n"
" %5% Query(%7%) const{\n"
" %5% Query(%7%) override {\n"
" %4% graph(%12%);\n"
" %8%\n"
" }\n"
"// Wrapper query function for query class\n"
" bool DoQuery(gs::GraphDBSession& sess, Decoder& decoder, Encoder& "
"encoder) "
"override {\n"
" //decoding params from decoder, and call real query func\n"
" %9%\n"
" %4% %6%(sess);"
" auto res = Query(%10%);\n"
" // dump results to string\n"
" std::string res_str = res.SerializeAsString();\n"
" // encode results to encoder\n"
" if (!res_str.empty()){\n"
" encoder.put_string_view(res_str);\n"
" }\n"
" return true;\n"
" }\n"
" //private members\n"
" private:\n"
"};\n"
"} // namespace gs\n"
"\n"
Expand Down Expand Up @@ -232,15 +215,17 @@ class QueryGenerator {
ss << std::endl;
expr_code = ss.str();
}
std::string dynamic_vars_str =
ctx_.GetGraphInterface() + "& " + ctx_.GraphVar();
std::string dynamic_vars_str = std::string("const ") +
ctx_.GetSessionTypeName() + "& " +
ctx_.SessionVar();
if (ctx_.GetParameterVars().size() > 0) {
dynamic_vars_str += ", ";
dynamic_vars_str += concat_param_vars(ctx_.GetParameterVars());
}
std::string decoding_params_code, decoded_params_str;
std::tie(decoding_params_code, decoded_params_str) =
decode_params_from_decoder(ctx_.GetParameterVars());
auto param_types = get_param_types(ctx_.GetParameterVars());
std::string call_query_input_code = ctx_.GraphVar();
if (decoded_params_str.size() > 0) {
call_query_input_code += ", " + decoded_params_str;
Expand All @@ -249,7 +234,8 @@ class QueryGenerator {
formater % ctx_.GetGraphHeader() % expr_code % ctx_.GetQueryClassName() %
ctx_.GetGraphInterface() % ctx_.GetQueryRet() % ctx_.GraphVar() %
dynamic_vars_str % query_code % decoding_params_code %
call_query_input_code % get_app_base_name();
call_query_input_code % get_app_base_name(param_types) %
ctx_.SessionVar();
return formater.str();
}

Expand All @@ -267,7 +253,39 @@ class QueryGenerator {
// This info should be parse from physical plan.
// Currently always return writeAppBase, since physical plan hasn't
// provided this info.
std::string get_app_base_name() { return "CypherInternalPbWriteAppBase"; }
std::string get_app_base_name(const std::vector<std::string>& param_types) {
std::stringstream ss;
ss << "CypherReadAppBase<";
for (size_t i = 0; i < param_types.size(); ++i) {
ss << param_types[i];
if (i != param_types.size() - 1) {
ss << ", ";
}
}
ss << ">";
return ss.str();
}

std::vector<std::string> get_param_types(
std::vector<codegen::ParamConst> param_vars) {
std::vector<std::string> param_types;
if (param_vars.size() > 0) {
sort(param_vars.begin(), param_vars.end(),
[](const auto& a, const auto& b) { return a.id < b.id; });
CHECK(param_vars[0].id == 0);
for (size_t i = 0; i < param_vars.size(); ++i) {
if (i > 0 && param_vars[i].id == param_vars[i - 1].id) {
CHECK(param_vars[i].var_name == param_vars[i - 1].var_name)
<< " " << param_vars[i].var_name << " "
<< param_vars[i - 1].var_name;
continue;
} else {
param_types.push_back(data_type_2_string(param_vars[i].type, false));
}
}
}
return param_types;
}

// copy the param vars to sort
std::string concat_param_vars(
Expand All @@ -285,7 +303,7 @@ class QueryGenerator {
<< param_vars[i - 1].var_name;
continue;
} else {
ss << data_type_2_string(param_vars[i].type) << " "
ss << data_type_2_string(param_vars[i].type, false) << " "
<< param_vars[i].var_name << ",";
}
}
Expand Down
17 changes: 5 additions & 12 deletions flex/engines/graph_db/app/builtin/count_vertices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@

namespace gs {

bool CountVertices::DoQuery(GraphDBSession& sess, Decoder& input,
Encoder& output) {
results::CollectiveResults CountVertices::Query(const GraphDBSession& sess,
std::string label_name) {
// First get the read transaction.
auto txn = sess.GetReadTransaction();
// We expect one param of type string from decoder.
if (input.empty()) {
return false;
}
std::string label_name{input.get_string()};
const auto& schema = txn.schema();
if (!schema.has_vertex_label(label_name)) {
output.put_string_view("The requested label doesn't exits.");
return false; // The requested label doesn't exits.
LOG(ERROR) << "Label " << label_name << " not found in schema.";
return results::CollectiveResults();
}
auto label_id = schema.get_vertex_label_id(label_name);
// The vertices are labeled internally from 0 ~ vertex_label_num, accumulate
Expand All @@ -42,10 +38,7 @@ bool CountVertices::DoQuery(GraphDBSession& sess, Decoder& input,
->mutable_element()
->mutable_object()
->set_i32(vertex_num);

output.put_string_view(results.SerializeAsString());
txn.Commit();
return true;
return results;
}

AppWrapper CountVerticesFactory::CreateApp(const GraphDB& db) {
Expand Down
5 changes: 3 additions & 2 deletions flex/engines/graph_db/app/builtin/count_vertices.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@

namespace gs {
// A simple app to count the number of vertices of a given label.
class CountVertices : public CypherInternalPbWriteAppBase {
class CountVertices : public CypherReadAppBase<std::string> {
public:
CountVertices() {}
bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override;
results::CollectiveResults Query(const GraphDBSession& sess,
std::string param) override;
};

class CountVerticesFactory : public AppFactoryBase {
Expand Down
31 changes: 13 additions & 18 deletions flex/engines/graph_db/app/builtin/k_hop_neighbors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,20 @@

namespace gs {

bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input,
Encoder& output) {
results::CollectiveResults KNeighbors::Query(const GraphDBSession& sess,
std::string label_name,
int64_t vertex_id, int32_t k) {
auto txn = sess.GetReadTransaction();
Schema schema_ = txn.schema();
if (input.empty()) {
return false;
}
int64_t vertex_id_ = input.get_long();
std::string label_name{input.get_string()};
int k = input.get_int();
const Schema& schema_ = txn.schema();

if (k <= 0) {
output.put_string_view("k must be greater than 0.");
return false;
LOG(ERROR) << "k must be greater than 0.";
return {};
}
if (!schema_.has_vertex_label(label_name)) {
output.put_string_view("The requested label doesn't exits.");
return false; // The requested label doesn't exits.
// output.put_string_view("The requested label doesn't exits.");
LOG(ERROR) << "The requested label doesn't exits.";
return {};
}
label_t vertex_label_ = schema_.get_vertex_label_id(label_name);
struct pair_hash {
Expand All @@ -55,9 +51,9 @@ bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input,

nei_label_.push_back(vertex_label_);
vid_t vertex_index{};
if (!txn.GetVertexIndex(vertex_label_, (int64_t) vertex_id_, vertex_index)) {
output.put_string_view("get index fail.");
return false;
if (!txn.GetVertexIndex(vertex_label_, vertex_id, vertex_index)) {
LOG(ERROR) << "Vertex not found.";
return {};
}
nei_index_.push_back(vertex_index);
// get k hop neighbors
Expand Down Expand Up @@ -123,10 +119,9 @@ bool KNeighbors::DoQuery(GraphDBSession& sess, Decoder& input,
->mutable_object()
->set_i64(txn.GetVertexId(vertex_.first, vertex_.second).AsInt64());
}
output.put_string_view(results.SerializeAsString());

txn.Commit();
return true;
return results;
}

AppWrapper KNeighborsFactory::CreateApp(const GraphDB& db) {
Expand Down
6 changes: 4 additions & 2 deletions flex/engines/graph_db/app/builtin/k_hop_neighbors.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
#include "flex/engines/hqps_db/app/interactive_app_base.h"

namespace gs {
class KNeighbors : public CypherInternalPbWriteAppBase {
class KNeighbors : public CypherReadAppBase<std::string, int64_t, int32_t> {
public:
KNeighbors() {}
bool DoQuery(GraphDBSession& sess, Decoder& input, Encoder& output) override;
results::CollectiveResults Query(const GraphDBSession& sess,
std::string label_name, int64_t vertex_id,
int32_t hop_range) override;
};

class KNeighborsFactory : public AppFactoryBase {
Expand Down
Loading

0 comments on commit 5d89a61

Please sign in to comment.