Skip to content

Commit

Permalink
add codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
liulx20 committed Aug 15, 2024
1 parent 1a31576 commit 8facee5
Show file tree
Hide file tree
Showing 26 changed files with 1,871 additions and 164 deletions.
7 changes: 7 additions & 0 deletions flex/engines/graph_db/runtime/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ target_link_libraries(runtime_adhoc runtime_common)
install_flex_target(runtime_adhoc)


file(GLOB_RECURSE CODEGEN_SOURCES "codegen/*.cc" "common/types.cc")
add_library(runtime_codegen SHARED ${CODEGEN_SOURCES})
target_link_libraries(runtime_codegen hqps_plan_proto flex_utils flex_rt_mutable_graph)
install_flex_target(runtime_codegen)

add_executable(codegen codegen/codegen.cc)
target_link_libraries(codegen runtime_codegen)
6 changes: 3 additions & 3 deletions flex/engines/graph_db/runtime/adhoc/operators/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ Context eval_scan(const physical::Scan& scan_opr, const ReadTransaction& txn,

bool scan_oid;
if (is_find_vertex(scan_opr, params, label, vertex_id, alias, scan_oid)) {
return Scan::find_vertex(txn, label, vertex_id, alias, scan_oid);
return Scan::find_vertex_with_id(txn, label, vertex_id, alias, scan_oid);
}

const auto& opt = scan_opr.scan_opt();
Expand Down Expand Up @@ -191,7 +191,7 @@ Context eval_scan(const physical::Scan& scan_opr, const ReadTransaction& txn,
expr->eval_vertex(label, vid, 0).as_bool();
});
} else {
return Scan::scan_gid_vertex(
return Scan::filter_gids(
txn, scan_params,
[&expr, oids](label_t label, vid_t vid) {
return expr->eval_vertex(label, vid, 0).as_bool();
Expand All @@ -213,7 +213,7 @@ Context eval_scan(const physical::Scan& scan_opr, const ReadTransaction& txn,
oids.end();
});
} else {
return Scan::scan_gid_vertex(
return Scan::filter_gids(
txn, scan_params, [](label_t, vid_t) { return true; }, oids);
}
}
Expand Down
6 changes: 0 additions & 6 deletions flex/engines/graph_db/runtime/adhoc/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ namespace gs {

namespace runtime {

enum class VarType {
kVertexVar,
kEdgeVar,
kPathVar,
};

class VarGetterBase {
public:
virtual ~VarGetterBase() = default;
Expand Down
40 changes: 40 additions & 0 deletions flex/engines/graph_db/runtime/codegen/builders/builders.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/** Copyright 2020 Alibaba Group Holding Limited.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef RUNTIME_CODEGEN_BUILDERS_BUILDERS_H_
#define RUNTIME_CODEGEN_BUILDERS_BUILDERS_H_
#include <sstream>
#include <string>

#include "flex/engines/graph_db/runtime/codegen/building_context.h"
#include "flex/proto_generated_gie/algebra.pb.h"
#include "flex/proto_generated_gie/common.pb.h"
#include "flex/proto_generated_gie/expr.pb.h"
#include "flex/proto_generated_gie/physical.pb.h"

namespace gs {
namespace runtime {

std::string BuildScan(BuildingContext& context, const physical::Scan& opr);

std::string BuildSink(BuildingContext& context);

std::string BuildLimit(BuildingContext& context, const algebra::Limit& opr);

std::string BuildGetV(BuildingContext& context, const physical::GetV& opr);

} // namespace runtime
} // namespace gs
#endif // RUNTIME_CODEGEN_BUILDERS_BUILDERS_H_
18 changes: 18 additions & 0 deletions flex/engines/graph_db/runtime/codegen/builders/get_v_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "flex/engines/graph_db/runtime/codegen/builders/builders.h"
namespace gs {
namespace runtime {
class GetVBuilder {
public:
GetVBuilder(BuildingContext& context) : context_(context) {};
std::string Build(const physical::GetV& opr) {
std::stringstream ss;
ss << "GetVBuilder::Build()";
return ss.str();
}
BuildingContext& context_;
};
std::string BuildGetV(BuildingContext& context, const physical::GetV& opr) {
return GetVBuilder(context).Build(opr);
}
} // namespace runtime
} // namespace gs
38 changes: 38 additions & 0 deletions flex/engines/graph_db/runtime/codegen/builders/limit_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "flex/engines/graph_db/runtime/codegen/builders/builders.h"

namespace gs {
namespace runtime {
class LimitBuilder {
public:
LimitBuilder(BuildingContext& context) : context_(context) {};

std::string Build() { return ""; }

LimitBuilder& Lower(int lower) {
lower_ = lower;
return *this;
}
LimitBuilder& Upper(int upper) {
upper_ = upper;
return *this;
}

BuildingContext& context_;
int lower_;
int upper_;
};

std::string BuildLimit(BuildingContext& context, const algebra::Limit& opr) {
LimitBuilder builder(context);
int lower = 0;
int upper = std::numeric_limits<int>::max();
if (opr.has_range()) {
lower = std::max(lower, static_cast<int>(opr.range().lower()));
upper = std::min(upper, static_cast<int>(opr.range().upper()));
}

// TODO
return builder.Build();
}
} // namespace runtime
} // namespace gs
234 changes: 234 additions & 0 deletions flex/engines/graph_db/runtime/codegen/builders/scan_builder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
#include "flex/engines/graph_db/runtime/codegen/builders/builders.h"
#include "flex/engines/graph_db/runtime/codegen/exprs/expr_builder.h"
#include "flex/engines/graph_db/runtime/codegen/exprs/expr_utils.h"
#include "flex/engines/graph_db/runtime/codegen/utils/utils.h"
#include "flex/engines/graph_db/runtime/common/utils.h"

namespace gs {
namespace runtime {
class ScanBuilder {
public:
ScanBuilder(BuildingContext& context) : context_(context) {}

bool is_find_vertex(const physical::Scan& scan_opr, label_t& label,
int64_t& vertex_id, int& alias, bool& scan_oid,
std::string& expr_name, std::string& expr_str) const {
if (scan_opr.scan_opt() != physical::Scan::VERTEX) {
return false;
}
if (scan_opr.has_alias()) {
alias = scan_opr.alias().value();
} else {
alias = -1;
}
if (!scan_opr.has_params()) {
return false;
}

const auto& params = scan_opr.params();
if (params.tables_size() != 1) {
return false;
}
const auto& table = params.tables(0);
label = static_cast<int>(table.id());
if (!scan_opr.has_idx_predicate()) {
return false;
}
const auto& idx_predicate = scan_opr.idx_predicate();
if (idx_predicate.or_predicates_size() != 1) {
return false;
}

if (idx_predicate.or_predicates(0).predicates_size() != 1) {
return false;
}
const auto& predicate = idx_predicate.or_predicates(0).predicates(0);
if (!predicate.has_key()) {
return false;
}
auto key = predicate.key();
if (key.has_key()) {
scan_oid = true;
} else if (key.has_id()) {
// scan gid
scan_oid = false;
} else {
LOG(FATAL) << "Invalid key type";
}
switch (predicate.value_case()) {
case algebra::IndexPredicate_Triplet::ValueCase::kConst: {
RTAnyType type;
std::tie(expr_str, expr_name, type) =
value_pb_2_str(context_, predicate.const_());
break;
}
case algebra::IndexPredicate_Triplet::ValueCase::kParam: {
RTAnyType type;
std::tie(expr_str, expr_name, type) =
param_pb_2_str(context_, predicate.param());
break;
}
default:
LOG(FATAL) << "Invalid value type";
}
return true;
}

Check notice on line 76 in flex/engines/graph_db/runtime/codegen/builders/scan_builder.cc

View check run for this annotation

codefactor.io / CodeFactor

flex/engines/graph_db/runtime/codegen/builders/scan_builder.cc#L13-L76

Complex Method
bool parse_idx_predicate(const algebra::IndexPredicate& predicate,
std::vector<int64_t>& oids, bool& scan_oid) const {
if (predicate.or_predicates_size() != 1) {
return false;
}

if (predicate.or_predicates(0).predicates_size() != 1) {
return false;
}
const auto& triplet = predicate.or_predicates(0).predicates(0);
if (!triplet.has_key()) {
return false;
}
auto key = triplet.key();
if (key.has_key()) {
scan_oid = true;
} else if (key.has_id()) {
scan_oid = false;
} else {
LOG(FATAL) << "unexpected key case";
}

if (triplet.cmp() != common::Logical::EQ &&
triplet.cmp() != common::Logical::WITHIN) {
return false;
}

if (triplet.value_case() ==
algebra::IndexPredicate_Triplet::ValueCase::kConst) {
if (triplet.const_().item_case() == common::Value::kI32) {
oids.push_back(triplet.const_().i32());
} else if (triplet.const_().item_case() == common::Value::kI64) {
oids.push_back(triplet.const_().i64());
} else if (triplet.const_().item_case() == common::Value::kI64Array) {
const auto& array = triplet.const_().i64_array();
for (int i = 0; i < array.item_size(); ++i) {
oids.push_back(array.item(i));
}

} else {
LOG(FATAL) << "unexpected value case" << triplet.const_().item_case();
}
}
return true;
}

std::string Build(const physical::Scan& scan_opr) {
{
label_t label;
int64_t vertex_id;
int alias;
bool scan_oid;
std::string expr_name, expr_str;
std::stringstream ss;
if (is_find_vertex(scan_opr, label, vertex_id, alias, scan_oid, expr_name,
expr_str)) {
auto ctx = context_.GetNextCtxName();
ss << expr_str << "\n";
ss << ctx << " = find_vertex(txn, " << static_cast<int>(label) << ", "
<< expr_name << ", " << alias << ", "
<< (scan_oid ? "true" : "false") << ");\n";
context_.set_alias(alias, ContextColumnType::kVertex,
RTAnyType::kVertex);
return ss.str();
}
}
const auto& opt = scan_opr.scan_opt();
CHECK(opt == physical::Scan::VERTEX) << "Unsupported scan option";
ScanParams scan_params;
if (scan_opr.has_alias()) {
scan_params.alias = scan_opr.alias().value();
} else {
scan_params.alias = -1;
}
context_.set_alias(scan_params.alias, ContextColumnType::kVertex,
RTAnyType::kVertex);
CHECK(scan_opr.has_params()) << "Scan params is not set";
const auto& scan_opr_params = scan_opr.params();
for (const auto& table : scan_opr_params.tables()) {
scan_params.tables.push_back(table.id());
}

if (scan_opr_params.has_predicate()) {
auto [name, str] =
BuildExpr(context_, scan_opr_params.predicate(), VarType::kVertexVar);
if (scan_opr.has_idx_predicate()) {
const auto& idx_predicate = scan_opr.idx_predicate();
std::vector<int64_t> oids;
bool scan_oid;
CHECK(parse_idx_predicate(idx_predicate, oids, scan_oid))
<< "Invalid idx predicate";
std::stringstream ss;
ss << str << "\n";
if (scan_oid) {
ss << "Scan::filter_oids(txn, " << scan_params.toString() << ", ["
<< name << "](label_t label, vid_t vid){\n return " << name
<< ".typed_eval_vertex(label, vid, 0);\n"
<< "}, ";
} else {
ss << "Scan::filter_gids(txn, " << scan_params.toString() << ", ["
<< name << "](label_t label, vid_t vid){\n return " << name
<< ".typed_eval_vertex(label, vid, 0);\n"
<< "}, ";
}
ss << vec_2_str(oids) << ");\n";
return ss.str();
} else {
std::stringstream ss;
ss << str << "\n";
ss << "Scan::scan_vertex(txn, " << scan_params.toString() << ", ["
<< name << "](label_t label, vid_t vid){\n return " << name
<< ".typed_eval_vertex(label, vid, 0);\n"
<< "});\n";
return ss.str();
}
}
if (scan_opr.has_idx_predicate()) {
const auto& idx_predicate = scan_opr.idx_predicate();
std::vector<int64_t> oids;
bool scan_oid;
CHECK(parse_idx_predicate(idx_predicate, oids, scan_oid))
<< "Invalid idx predicate";
if (scan_oid) {
std::stringstream ss;
ss << "Scan::filter_oids(txn, " << scan_params.toString()
<< "[](label_t label, vid_t vid){\n"
<< "return true;\n"
<< "}, " << vec_2_str(oids) << ");\n";
return ss.str();
} else {
std::stringstream ss;
ss << "Scan::filter_gids(txn, " << scan_params.toString()
<< "[](label_t label, vid_t vid){\n"
<< "return true;\n"
<< "}, " << vec_2_str(oids) << ");\n";
return ss.str();
}

} else {
std::stringstream ss;
ss << "Scan::scan_vertex(txn, " << scan_params.toString()
<< "[](label_t label, vid_t vid){\n"
<< "return true;\n"
<< "});\n";
return ss.str();
}
LOG(FATAL) << "not support to reach here";
return "";
}

BuildingContext& context_;
};

std::string BuildScan(BuildingContext& context, const physical::Scan& opr) {
return ScanBuilder(context).Build(opr);
}
} // namespace runtime
} // namespace gs
Loading

0 comments on commit 8facee5

Please sign in to comment.