diff --git a/docs/interactive_engine/neo4j/supported_cypher.md b/docs/interactive_engine/neo4j/supported_cypher.md index b14b46b24c4c..21261ad9cb65 100644 --- a/docs/interactive_engine/neo4j/supported_cypher.md +++ b/docs/interactive_engine/neo4j/supported_cypher.md @@ -98,6 +98,8 @@ Note that some Aggregator operators, such as `max()`, we listed here are impleme | Branch | Use with `Project` and `Return` | CASE WHEN | CASE WHEN | | planned | | Scalar | Returns the length of a path | length() | length() | | | | List | Fold expressions into a single list | [] | [] | | | +| Labels | Get label name of a vertex type | labels() | labels() | | | +| Type | Get label name of an edge type | type() | type() | | | diff --git a/flex/codegen/src/graph_types.h b/flex/codegen/src/graph_types.h index 8b8b074bb791..7f5a44525b34 100644 --- a/flex/codegen/src/graph_types.h +++ b/flex/codegen/src/graph_types.h @@ -43,6 +43,7 @@ enum class DataType { kTime = 11, kDate = 12, kDateTime = 13, + kLabelId = 14 }; // a parameter const, the real data will be feed at runtime. @@ -193,6 +194,8 @@ static std::string data_type_2_string(const codegen::DataType& data_type) { return EDGE_ID_T; case codegen::DataType::kDate: return "Date"; + case codegen::DataType::kLabelId: + return "LabelKey"; default: // LOG(FATAL) << "unknown data type" << static_cast(data_type); throw std::runtime_error( diff --git a/flex/codegen/src/hqps/hqps_project_builder.h b/flex/codegen/src/hqps/hqps_project_builder.h index 6e3653cad443..acf5350640f7 100644 --- a/flex/codegen/src/hqps/hqps_project_builder.h +++ b/flex/codegen/src/hqps/hqps_project_builder.h @@ -214,6 +214,10 @@ std::string project_variable_mapping_to_string(BuildingContext& ctx, } else if (prop.item_case() == common::Property::kLen) { prop_names.push_back("length"); data_types.push_back(codegen::DataType::kLength); + } else if (prop.item_case() == common::Property::kLabel) { + // return the label id. + prop_names.push_back("label"); + data_types.push_back(codegen::DataType::kLabelId); } else { LOG(FATAL) << "Unknown property type" << prop.DebugString(); } diff --git a/flex/engines/hqps_db/core/operator/project.h b/flex/engines/hqps_db/core/operator/project.h index 08ccab28f1c5..0f60d0a7c213 100644 --- a/flex/engines/hqps_db/core/operator/project.h +++ b/flex/engines/hqps_db/core/operator/project.h @@ -291,8 +291,31 @@ class ProjectOp { ///////////////////Project implementation for all data structures. + /// Special case for project for labelKey + template < + typename T, typename NODE_T, + typename std::enable_if>::type* = nullptr> + static auto apply_single_project_impl( + const GRAPH_INTERFACE& graph, NODE_T& node, const std::string& prop_name, + const std::vector& repeat_array) { + auto size = node.Size(); + auto label_vec = node.GetLabelVec(); + std::vector res_prop_vec; + CHECK(label_vec.size() == repeat_array.size()) + << "label size: " << label_vec.size() + << " repeat size: " << repeat_array.size(); + for (auto i = 0; i < repeat_array.size(); ++i) { + for (auto j = 0; j < repeat_array[i]; ++j) { + res_prop_vec.emplace_back(label_vec[i]); + } + } + return Collection(std::move(res_prop_vec)); + } + // single label vertex set. - template + template < + typename T, typename LabelT, typename VID_T, typename... SET_T, + typename std::enable_if<(!std::is_same_v)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, RowVertexSetImpl& node, @@ -315,8 +338,10 @@ class ProjectOp { } // single keyed label vertex set. - template + template < + typename T, typename LabelT, typename KEY_T, typename VID_T, + typename... SET_T, + typename std::enable_if<(!std::is_same_v)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, KeyedRowVertexSetImpl& node, @@ -340,7 +365,9 @@ class ProjectOp { } // project for two label vertex set. - template + template < + typename T, typename VID_T, typename LabelT, typename... SET_T, + typename std::enable_if<(!std::is_same_v)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, TwoLabelVertexSetImpl& node, @@ -379,7 +406,9 @@ class ProjectOp { } // general vertex set. - template + template < + typename T, typename VID_T, typename LabelT, + typename std::enable_if<(!std::is_same_v)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, GeneralVertexSet& node, const std::string& prop_name_, const std::vector& repeat_array) { @@ -419,8 +448,10 @@ class ProjectOp { } // single label edge set - template ::type* = nullptr> + template < + typename T, typename NODE_T, + typename std::enable_if)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, NODE_T& node, const std::string& prop_name, const std::vector& repeat_array) { @@ -449,7 +480,9 @@ class ProjectOp { } /// Apply project on untyped edge set. - template + template < + typename T, typename VID_T, typename LabelT, typename SUB_GRAPH_T, + typename std::enable_if<(!std::is_same_v)>::type* = nullptr> static auto apply_single_project_impl( const GRAPH_INTERFACE& graph, UnTypedEdgeSet& node, diff --git a/flex/engines/hqps_db/core/operator/sink.h b/flex/engines/hqps_db/core/operator/sink.h index 0513ab38398a..ab3a0ac27f85 100644 --- a/flex/engines/hqps_db/core/operator/sink.h +++ b/flex/engines/hqps_db/core/operator/sink.h @@ -378,10 +378,11 @@ class SinkOp { } } - // sink collection of pod + // sink collection of pod, expect for LabelKey type template ::value) && - (!gs::is_tuple::value)>::type* = nullptr> + typename std::enable_if< + (!gs::is_vector::value) && (!gs::is_tuple::value) && + (!std::is_same::value)>::type* = nullptr> static void sink_col_impl(results::CollectiveResults& results_vec, const Collection& collection, const std::vector& repeat_offsets, @@ -417,6 +418,43 @@ class SinkOp { } } + // sink collection of LabelKey + template + static void sink_col_impl(results::CollectiveResults& results_vec, + const Collection& collection, + const std::vector& repeat_offsets, + int32_t tag_id) { + if (repeat_offsets.empty()) { + CHECK(collection.Size() == results_vec.results_size()) + << "size neq " << collection.Size() << " " + << results_vec.results_size(); + for (auto i = 0; i < collection.Size(); ++i) { + auto row = results_vec.mutable_results(i); + CHECK(row->record().columns_size() == Ind); + auto record = row->mutable_record(); + auto new_col = record->add_columns(); + new_col->mutable_name_or_id()->set_id(tag_id); + auto obj = + new_col->mutable_entry()->mutable_element()->mutable_object(); + obj->set_i32(collection.Get(i).label_id); + } + } else { + CHECK(repeat_offsets.size() == collection.Size()); + size_t cur_ind = 0; + for (auto i = 0; i < collection.Size(); ++i) { + for (auto j = 0; j < repeat_offsets[i]; ++j) { + auto row = results_vec.mutable_results(cur_ind++); + auto record = row->mutable_record(); + auto new_col = record->add_columns(); + new_col->mutable_name_or_id()->set_id(tag_id); + auto obj = + new_col->mutable_entry()->mutable_element()->mutable_object(); + obj->set_i32(collection.Get(i).label_id); + } + } + } + } + // sinke for tuple with one element template struct is_label_key_prop : std::false_type {}; diff --git a/flex/engines/hqps_db/core/utils/hqps_utils.h b/flex/engines/hqps_db/core/utils/hqps_utils.h index 0d7e90a36dcc..a7f24ae210e2 100644 --- a/flex/engines/hqps_db/core/utils/hqps_utils.h +++ b/flex/engines/hqps_db/core/utils/hqps_utils.h @@ -847,6 +847,13 @@ struct to_string_impl { } }; +template <> +struct to_string_impl { + static inline std::string to_string(const LabelKey& label_key) { + return std::to_string(label_key.label_id); + } +}; + template <> struct to_string_impl { static inline std::string to_string(const Direction& opt) { diff --git a/flex/engines/hqps_db/structures/multi_edge_set/adj_edge_set.h b/flex/engines/hqps_db/structures/multi_edge_set/adj_edge_set.h index 4eae9d2ce7af..e2f5c9ac297e 100644 --- a/flex/engines/hqps_db/structures/multi_edge_set/adj_edge_set.h +++ b/flex/engines/hqps_db/structures/multi_edge_set/adj_edge_set.h @@ -298,6 +298,12 @@ class AdjEdgeSet { return builder_t(src_label_, dst_label_, edge_label_, prop_names_, dir_); } + std::vector GetLabelVec() const { + std::vector label_vec(Size()); + std::fill(label_vec.begin(), label_vec.end(), {edge_label_}); + return label_vec; + } + iterator begin() const { return iterator(vids_, adj_lists_, 0); } iterator end() const { return iterator(vids_, adj_lists_, vids_.size()); } @@ -414,6 +420,12 @@ class AdjEdgeSet { iterator end() const { return iterator(vids_, adj_lists_, vids_.size()); } + std::vector GetLabelVec() const { + std::vector label_vec(Size()); + std::fill(label_vec.begin(), label_vec.end(), {edge_label_}); + return label_vec; + } + size_t Size() const { return size_; } template diff --git a/flex/engines/hqps_db/structures/multi_edge_set/flat_edge_set.h b/flex/engines/hqps_db/structures/multi_edge_set/flat_edge_set.h index 4448523b5c48..df4a20b6bfaf 100644 --- a/flex/engines/hqps_db/structures/multi_edge_set/flat_edge_set.h +++ b/flex/engines/hqps_db/structures/multi_edge_set/flat_edge_set.h @@ -173,6 +173,16 @@ class FlatEdgeSet { iterator end() const { return iterator(vec_, vec_.size()); } + std::vector GetLabelVec() const { + std::vector res; + res.reserve(vec_.size()); + for (auto i = 0; i < vec_.size(); ++i) { + auto ind = label_triplet_ind_[i]; + res.emplace_back(label_triplet_[ind][2]); + } + return res; + } + template flat_t Flat( std::vector>& index_ele_tuple) const { @@ -498,6 +508,15 @@ class SingleLabelEdgeSet { iterator end() const { return iterator(vec_, vec_.size()); } + std::vector GetLabelVec() const { + std::vector res; + res.reserve(vec_.size()); + for (auto i = 0; i < vec_.size(); ++i) { + res.emplace_back(label_triplet_[2]); + } + return res; + } + template flat_t Flat( std::vector>& index_ele_tuple) const { diff --git a/flex/engines/hqps_db/structures/multi_edge_set/general_edge_set.h b/flex/engines/hqps_db/structures/multi_edge_set/general_edge_set.h index eccacc2a117d..46d1a17f33cf 100644 --- a/flex/engines/hqps_db/structures/multi_edge_set/general_edge_set.h +++ b/flex/engines/hqps_db/structures/multi_edge_set/general_edge_set.h @@ -428,6 +428,15 @@ class GeneralEdgeSet<2, GI, VID_T, LabelT, std::tuple, std::tuple> { iterator end() const { return iterator(vids_, adj_lists_, vids_.size()); } + std::vector GetLabelVec() const { + std::vector res; + res.reserve(Size()); + for (auto i = 0; i < Size(); ++i) { + res.emplace_back(edge_label_); + } + return res; + } + size_t Size() const { if (size_ == 0) { for (auto i = 0; i < adj_lists_.size(); ++i) { @@ -612,6 +621,15 @@ class GeneralEdgeSet<2, GI, VID_T, LabelT, std::tuple, bitsets_.swap(other.bitsets_); } + std::vector GetLabelVec() const { + std::vector res; + res.reserve(Size()); + for (auto i = 0; i < Size(); ++i) { + res.emplace_back(edge_label_); + } + return res; + } + iterator begin() const { return iterator(vids_, adj_lists_, 0); } iterator end() const { return iterator(vids_, adj_lists_, vids_.size()); } diff --git a/flex/engines/hqps_db/structures/multi_edge_set/untyped_edge_set.h b/flex/engines/hqps_db/structures/multi_edge_set/untyped_edge_set.h index fcdc9d1a42a8..cb86de6dc2f6 100644 --- a/flex/engines/hqps_db/structures/multi_edge_set/untyped_edge_set.h +++ b/flex/engines/hqps_db/structures/multi_edge_set/untyped_edge_set.h @@ -195,6 +195,26 @@ class UnTypedEdgeSet { src_vertices_.size()); } + std::vector GetLabelVec() const { + std::vector res; + res.reserve(Size()); + for (auto i = 0; i < src_vertices_.size(); ++i) { + auto label_ind = label_indices_[i]; + auto label = src_labels_[label_ind]; + if (adj_lists_.find(label) != adj_lists_.end()) { + auto& sub_graphs = adj_lists_.at(label); + for (auto& sub_graph : sub_graphs) { + auto edge_iters = sub_graph.get_edges(src_vertices_[i]); + auto edge_label = sub_graph.GetEdgeLabel(); + for (auto j = 0; j < edge_iters.Size(); ++j) { + res.emplace_back(edge_label); + } + } + } + } + return res; + } + size_t Size() const { if (size_ == 0) { auto iter_vec = generate_iters(); diff --git a/flex/engines/hqps_db/structures/multi_vertex_set/general_vertex_set.h b/flex/engines/hqps_db/structures/multi_vertex_set/general_vertex_set.h index 097bd008bda8..df086f248744 100644 --- a/flex/engines/hqps_db/structures/multi_vertex_set/general_vertex_set.h +++ b/flex/engines/hqps_db/structures/multi_vertex_set/general_vertex_set.h @@ -78,7 +78,7 @@ auto general_project_vertices_impl( break; } } else { - if (expr(eles)) { + if (expr(std::get<0>(eles))) { res_bitsets[label_id].set_bit(res_vec.size()); res_vec.push_back(old_vec[i]); break; @@ -381,6 +381,20 @@ class GeneralVertexSet { LabelT GetLabel(size_t i) const { return label_names_[i]; } + const std::vector GetLabelVec() const { + std::vector res; + // fill res with vertex labels + for (auto i = 0; i < vec_.size(); ++i) { + for (auto j = 0; j < bitsets_.size(); ++j) { + if (bitsets_[j].get_bit(i)) { + res.emplace_back(label_names_[j]); + break; + } + } + } + return res; + } + // generate label indices. std::vector GenerateLabelIndices() const { std::vector label_indices; diff --git a/flex/engines/hqps_db/structures/multi_vertex_set/keyed_row_vertex_set.h b/flex/engines/hqps_db/structures/multi_vertex_set/keyed_row_vertex_set.h index af70aa8379f0..170a10972e41 100644 --- a/flex/engines/hqps_db/structures/multi_vertex_set/keyed_row_vertex_set.h +++ b/flex/engines/hqps_db/structures/multi_vertex_set/keyed_row_vertex_set.h @@ -569,6 +569,16 @@ class KeyedRowVertexSetImpl { LabelT GetLabel() const { return v_label_; } + const std::vector GetLabelVec() { + std::vector res; + // fill res with vertex labels + res.reserve(vids_.size()); + for (auto i = 0; i < vids_.size(); ++i) { + res.emplace_back(v_label_); + } + return res; + } + const std::array& GetPropNames() const { return prop_names_; } @@ -768,6 +778,16 @@ class KeyedRowVertexSetImpl { LabelT GetLabel() const { return v_label_; } + const std::vector GetLabelVec() { + std::vector res; + // fill res with vertex labels + res.reserve(vids_.size()); + for (auto i = 0; i < vids_.size(); ++i) { + res.emplace_back(v_label_); + } + return res; + } + const std::vector& GetVertices() const { return vids_; } builder_t CreateBuilder() const { return builder_t(v_label_); } diff --git a/flex/engines/hqps_db/structures/multi_vertex_set/row_vertex_set.h b/flex/engines/hqps_db/structures/multi_vertex_set/row_vertex_set.h index 6f3a4ef6cb54..696eda5146c9 100644 --- a/flex/engines/hqps_db/structures/multi_vertex_set/row_vertex_set.h +++ b/flex/engines/hqps_db/structures/multi_vertex_set/row_vertex_set.h @@ -950,6 +950,16 @@ class RowVertexSetImpl { const LabelT& GetLabel() const { return v_label_; } + const std::vector GetLabelVec() { + std::vector res; + // fill res with v_label_ + res.reserve(vids_.size()); + for (auto i = 0; i < vids_.size(); ++i) { + res.emplace_back(v_label_); + } + return res; + } + const std::vector& GetVertices() const { return vids_; } std::vector& GetMutableVertices() { return vids_; } @@ -1244,6 +1254,16 @@ class RowVertexSetImpl { const LabelT& GetLabel() const { return v_label_; } + const std::vector GetLabelVec() { + std::vector res; + // fill res with v_label_ + res.reserve(vids_.size()); + for (auto i = 0; i < vids_.size(); ++i) { + res.emplace_back(v_label_); + } + return res; + } + const std::vector& GetVertices() const { return vids_; } std::vector& GetMutableVertices() { return vids_; } std::vector&& MoveVertices() { return std::move(vids_); } diff --git a/flex/engines/hqps_db/structures/multi_vertex_set/two_label_vertex_set.h b/flex/engines/hqps_db/structures/multi_vertex_set/two_label_vertex_set.h index d5c008f3cac2..dea4fb335f3e 100644 --- a/flex/engines/hqps_db/structures/multi_vertex_set/two_label_vertex_set.h +++ b/flex/engines/hqps_db/structures/multi_vertex_set/two_label_vertex_set.h @@ -744,6 +744,19 @@ class TwoLabelVertexSetImpl { const std::array& GetLabels() const { return label_names_; } + const std::vector GetLabelVec() { + std::vector res; + // fill with each vertex's label + for (auto i = 0; i < vec_.size(); ++i) { + if (bitset_.get_bit(i)) { + res.emplace_back(label_names_[0]); + } else { + res.emplace_back(label_names_[1]); + } + } + return res; + } + LabelT GetLabel(size_t i) const { return label_names_[i]; } const grape::Bitset& GetBitset() const { return bitset_; } @@ -1006,6 +1019,19 @@ class TwoLabelVertexSetImpl { const std::array& GetLabels() const { return label_names_; } + const std::vector GetLabelVec() { + std::vector res; + // fill with each vertex's label + for (auto i = 0; i < vec_.size(); ++i) { + if (bitset_.get_bit(i)) { + res.emplace_back(label_names_[0]); + } else { + res.emplace_back(label_names_[1]); + } + } + return res; + } + LabelT GetLabel(size_t i) const { return label_names_[i]; } const grape::Bitset& GetBitset() const { return bitset_; } diff --git a/flex/tests/hqps/match_query.h b/flex/tests/hqps/match_query.h index a770d8e4d0e3..6b424c99151d 100644 --- a/flex/tests/hqps/match_query.h +++ b/flex/tests/hqps/match_query.h @@ -623,5 +623,126 @@ class MatchQuery10 : public HqpsAppBase { } }; +struct MatchQuery11Expr0 { + public: + using result_t = bool; + MatchQuery11Expr0() {} + + inline auto operator()(int64_t id) const { return (true) && (id == 933); } + + private: +}; +struct MatchQuery11Expr1 { + public: + using result_t = bool; + MatchQuery11Expr1() {} + + inline auto operator()(int64_t id) const { + return (true) && (id == 2199023256077); + } + + private: +}; + +// Auto generated query class definition +class MatchQuery11 : public HqpsAppBase { + public: + using Engine = SyncEngine; + using label_id_t = typename gs::MutableCSRInterface::label_id_t; + using vertex_id_t = typename gs::MutableCSRInterface::vertex_id_t; + // Query function for query class + results::CollectiveResults Query(const gs::MutableCSRInterface& graph) const { + auto expr0 = gs::make_filter(MatchQuery11Expr0(), + gs::PropertySelector("id")); + auto ctx0 = Engine::template ScanVertex( + graph, std::array{0, 1, 2, 3, 4, 5, 6, 7}, + std::move(expr0)); + + auto edge_expand_opt0 = gs::make_edge_expand_multie_opt< + label_id_t, std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple>( + gs::Direction::Both, + std::array, 21>{ + std::array{2, 2, 2}, + std::array{2, 3, 2}, + std::array{1, 7, 6}, + std::array{6, 6, 13}, + std::array{4, 3, 3}, + std::array{2, 0, 7}, + std::array{1, 0, 7}, + std::array{3, 0, 7}, + std::array{5, 0, 7}, + std::array{1, 1, 8}, + std::array{1, 2, 9}, + std::array{1, 3, 9}, + std::array{0, 0, 11}, + std::array{7, 6, 12}, + std::array{2, 1, 0}, + std::array{3, 1, 0}, + std::array{1, 5, 10}, + std::array{4, 1, 4}, + std::array{1, 5, 14}, + std::array{3, 7, 1}, + std::array{4, 1, 5}}, + std::tuple{PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{"creationDate"}, + PropTupleArrayT>{"creationDate"}, + PropTupleArrayT>{"creationDate"}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{}, + PropTupleArrayT>{"workFrom"}, + PropTupleArrayT>{"joinDate"}, + PropTupleArrayT>{"classYear"}, + PropTupleArrayT>{}, + PropTupleArrayT>{}}); + auto ctx1 = + Engine::template EdgeExpandE( + graph, std::move(ctx0), std::move(edge_expand_opt0)); + + auto get_v_opt1 = + make_getv_opt(gs::VOpt::Other, std::array{}); + auto ctx2 = Engine::template GetV( + graph, std::move(ctx1), std::move(get_v_opt1)); + auto expr1 = gs::make_filter(MatchQuery11Expr1(), + gs::PropertySelector("id")); + auto get_v_opt2 = make_getv_opt( + gs::VOpt::Itself, std::array{}, std::move(expr1)); + auto ctx3 = Engine::template GetV( + graph, std::move(ctx2), std::move(get_v_opt2)); + auto ctx4 = Engine::Project( + graph, std::move(ctx3), + std::tuple{gs::make_mapper_with_variable( + gs::PropertySelector("label")), + gs::make_mapper_with_variable( + gs::PropertySelector("label"))}); + return Engine::Sink(ctx4, std::array{0, 1}); + } + // Wrapper query function for query class + results::CollectiveResults Query(const gs::MutableCSRInterface& graph, + Decoder& decoder) const override { + // decoding params from decoder, and call real query func + + return Query(graph); + } +}; + } // namespace gs #endif // TESTS_HQPS_MATCH_QUERY_H_ \ No newline at end of file diff --git a/flex/tests/hqps/query_test.cc b/flex/tests/hqps/query_test.cc index 1674238b978a..f59c5a73e8d9 100644 --- a/flex/tests/hqps/query_test.cc +++ b/flex/tests/hqps/query_test.cc @@ -201,5 +201,18 @@ int main(int argc, char** argv) { LOG(INFO) << "Finish MatchQuery10 test"; } + { + gs::MatchQuery11 query; + std::vector encoder_array; + gs::Encoder input_encoder(encoder_array); + std::vector output_array; + gs::Encoder output(output_array); + gs::Decoder input(encoder_array.data(), encoder_array.size()); + + gs::MutableCSRInterface graph(sess); + query.Query(graph, input); + LOG(INFO) << "Finish MatchQuery10 test"; + } + LOG(INFO) << "Finish context test."; } \ No newline at end of file diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GraphOptTable.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GraphOptTable.java index 4c9fed3c3b16..d47111c41906 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GraphOptTable.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/meta/schema/GraphOptTable.java @@ -21,7 +21,6 @@ import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; import com.alibaba.graphscope.common.ir.type.GraphLabelType; import com.alibaba.graphscope.common.ir.type.GraphSchemaType; -import com.alibaba.graphscope.common.ir.type.GraphSchemaTypeList; import com.alibaba.graphscope.groot.common.schema.api.*; import org.apache.calcite.linq4j.tree.Expression; @@ -76,27 +75,32 @@ private RelDataType deriveType(GraphElement element) { } if (element instanceof GraphVertex) { GraphLabelType labelType = - (new GraphLabelType()).label(element.getLabel()).labelId(element.getLabelId()); + new GraphLabelType( + new GraphLabelType.Entry() + .label(element.getLabel()) + .labelId(element.getLabelId())); return new GraphSchemaType(GraphOpt.Source.VERTEX, labelType, fields); } else if (element instanceof GraphEdge) { GraphEdge edge = (GraphEdge) element; List relations = edge.getRelationList(); List fuzzyTypes = new ArrayList<>(); for (EdgeRelation relation : relations) { - GraphLabelType labelType = - (new GraphLabelType()) + GraphLabelType.Entry labelEntry = + new GraphLabelType.Entry() .label(element.getLabel()) .labelId(element.getLabelId()); GraphVertex src = relation.getSource(); GraphVertex dst = relation.getTarget(); - labelType.srcLabel(src.getLabel()).dstLabel(dst.getLabel()); - labelType.srcLabelId(src.getLabelId()).dstLabelId(dst.getLabelId()); - fuzzyTypes.add(new GraphSchemaType(GraphOpt.Source.EDGE, labelType, fields)); + labelEntry.srcLabel(src.getLabel()).dstLabel(dst.getLabel()); + labelEntry.srcLabelId(src.getLabelId()).dstLabelId(dst.getLabelId()); + fuzzyTypes.add( + new GraphSchemaType( + GraphOpt.Source.EDGE, new GraphLabelType(labelEntry), fields)); } ObjectUtils.requireNonEmpty(fuzzyTypes); return (fuzzyTypes.size() == 1) ? fuzzyTypes.get(0) - : GraphSchemaTypeList.create(fuzzyTypes); + : GraphSchemaType.create(fuzzyTypes, getRelOptSchema().getTypeFactory()); } else { throw new IllegalArgumentException("element should be vertex or edge"); } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/graph/AbstractBindableTableScan.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/graph/AbstractBindableTableScan.java index 81e55ce4e5c9..ba61a19b188e 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/graph/AbstractBindableTableScan.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/rel/graph/AbstractBindableTableScan.java @@ -19,7 +19,6 @@ import com.alibaba.graphscope.common.ir.rel.type.TableConfig; import com.alibaba.graphscope.common.ir.tools.AliasInference; import com.alibaba.graphscope.common.ir.type.GraphSchemaType; -import com.alibaba.graphscope.common.ir.type.GraphSchemaTypeList; import com.google.common.collect.ImmutableList; import org.apache.calcite.plan.GraphOptCluster; @@ -30,6 +29,7 @@ import org.apache.calcite.rel.core.TableScan; import org.apache.calcite.rel.hint.RelHint; import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeFieldImpl; import org.apache.calcite.rel.type.RelRecordType; import org.apache.calcite.rex.RexNode; @@ -91,20 +91,17 @@ protected AbstractBindableTableScan( public RelDataType deriveRowType() { List tableTypes = new ArrayList<>(); List tables = ObjectUtils.requireNonEmpty(this.tableConfig.getTables()); + RelDataTypeFactory typeFactory = tables.get(0).getRelOptSchema().getTypeFactory(); for (RelOptTable table : tables) { GraphSchemaType type = (GraphSchemaType) table.getRowType(); // flat fuzzy labels to the list - if (type instanceof GraphSchemaTypeList) { - tableTypes.addAll((GraphSchemaTypeList) type); - } else { - tableTypes.add(type); - } + tableTypes.addAll(type.getSchemaTypeAsList()); } ObjectUtils.requireNonEmpty(tableTypes); GraphSchemaType graphType = (tableTypes.size() == 1) ? tableTypes.get(0) - : GraphSchemaTypeList.create(tableTypes); + : GraphSchemaType.create(tableTypes, typeFactory); RelRecordType rowType = new RelRecordType( ImmutableList.of( diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java index 0c051e1fc7b3..f9efd30c80c9 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/ffi/RelToFfiConverter.java @@ -35,7 +35,6 @@ import com.alibaba.graphscope.common.ir.type.GraphLabelType; import com.alibaba.graphscope.common.ir.type.GraphProperty; import com.alibaba.graphscope.common.ir.type.GraphSchemaType; -import com.alibaba.graphscope.common.ir.type.GraphSchemaTypeList; import com.alibaba.graphscope.common.jna.IrCoreLibrary; import com.alibaba.graphscope.common.jna.type.*; import com.alibaba.graphscope.gaia.proto.OuterExpression; @@ -442,7 +441,7 @@ private List getLeftRightVariables(RexNode condition) { private Pointer ffiQueryParams(AbstractBindableTableScan tableScan) { Set uniqueLabelIds = - getGraphLabels(tableScan).stream() + getGraphLabels(tableScan).getLabelsEntry().stream() .map(k -> k.getLabelId()) .collect(Collectors.toSet()); Pointer params = LIB.initQueryParams(); @@ -570,7 +569,7 @@ private void addFfiBinder(Pointer ptrSentence, RelNode binder, boolean isTail) { private void addFilterToFfiBinder(Pointer ptrSentence, AbstractBindableTableScan tableScan) { Set uniqueLabelIds = - getGraphLabels(tableScan).stream() + getGraphLabels(tableScan).getLabelsEntry().stream() .map(k -> k.getLabelId()) .collect(Collectors.toSet()); // add labels as select operator @@ -602,20 +601,14 @@ private void addFilterToFfiBinder(Pointer ptrSentence, AbstractBindableTableScan } } - private List getGraphLabels(AbstractBindableTableScan tableScan) { + private GraphLabelType getGraphLabels(AbstractBindableTableScan tableScan) { List fields = tableScan.getRowType().getFieldList(); Preconditions.checkArgument( !fields.isEmpty() && fields.get(0).getType() instanceof GraphSchemaType, "data type of graph operators should be %s ", GraphSchemaType.class); GraphSchemaType schemaType = (GraphSchemaType) fields.get(0).getType(); - List labelTypes = new ArrayList<>(); - if (schemaType instanceof GraphSchemaTypeList) { - ((GraphSchemaTypeList) schemaType).forEach(k -> labelTypes.add(k.getLabelType())); - } else { - labelTypes.add(schemaType.getLabelType()); - } - return labelTypes; + return schemaType.getLabelType(); } private void checkFfiResult(FfiResult res) { diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java index df8395834f38..1245a82bb6d6 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/runtime/proto/Utils.java @@ -192,6 +192,7 @@ public static final OuterExpression.ExprOpr protoOperator(SqlOperator operator) } public static final Common.DataType protoBasicDataType(RelDataType basicType) { + if (basicType instanceof GraphLabelType) return Common.DataType.INT32; switch (basicType.getSqlTypeName()) { case NULL: return Common.DataType.NONE; @@ -243,17 +244,10 @@ public static final DataType.IrDataType protoIrDataType( DataType.GraphDataType.Builder builder = DataType.GraphDataType.newBuilder(); builder.setElementOpt( protoElementOpt(((GraphSchemaType) dataType).getScanOpt())); - if (dataType instanceof GraphSchemaTypeList) { - ((GraphSchemaTypeList) dataType) - .forEach( - k -> { - builder.addGraphDataType( - protoElementType(k, isColumnId)); - }); - } else { - builder.addGraphDataType( - protoElementType((GraphSchemaType) dataType, isColumnId)); - } + ((GraphSchemaType) dataType) + .getSchemaTypeAsList() + .forEach( + k -> builder.addGraphDataType(protoElementType(k, isColumnId))); return DataType.IrDataType.newBuilder().setGraphType(builder.build()).build(); } throw new UnsupportedOperationException( @@ -322,14 +316,17 @@ public static final DataType.GraphDataType.GraphElementType protoElementType( public static final DataType.GraphDataType.GraphElementLabel protoElementLabel( GraphLabelType labelType) { + Preconditions.checkArgument( + labelType.getLabelsEntry().size() == 1, + "can not convert label=" + labelType + " to proto 'GraphElementLabel'"); + GraphLabelType.Entry entry = labelType.getSingleLabelEntry(); DataType.GraphDataType.GraphElementLabel.Builder builder = - DataType.GraphDataType.GraphElementLabel.newBuilder() - .setLabel(labelType.getLabelId()); - if (labelType.getSrcLabelId() != null) { - builder.setSrcLabel(Int32Value.of(labelType.getSrcLabelId())); + DataType.GraphDataType.GraphElementLabel.newBuilder().setLabel(entry.getLabelId()); + if (entry.getSrcLabelId() != null) { + builder.setSrcLabel(Int32Value.of(entry.getSrcLabelId())); } - if (labelType.getDstLabelId() != null) { - builder.setDstLabel(Int32Value.of(labelType.getDstLabelId())); + if (entry.getDstLabelId() != null) { + builder.setDstLabel(Int32Value.of(entry.getDstLabelId())); } return builder.build(); } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java index 7cb79604a7a4..6f3d9b3e2f98 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/tools/GraphBuilder.java @@ -418,12 +418,13 @@ public RexGraphVariable variable(@Nullable String alias, String property) { + "]"); } if (property.equals(GraphProperty.LABEL_KEY)) { + GraphSchemaType schemaType = (GraphSchemaType) aliasField.getType(); return RexGraphVariable.of( aliasField.getIndex(), new GraphProperty(GraphProperty.Opt.LABEL), columnField.left, varName, - getTypeFactory().createSqlType(SqlTypeName.CHAR)); + schemaType.getLabelType()); } else if (property.equals(GraphProperty.ID_KEY)) { return RexGraphVariable.of( aliasField.getIndex(), diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphLabelType.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphLabelType.java index aab24c674cd6..2c022e6b669b 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphLabelType.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphLabelType.java @@ -16,81 +16,165 @@ package com.alibaba.graphscope.common.ir.type; +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.sql.type.AbstractSqlType; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.commons.lang3.ObjectUtils; import org.apache.commons.lang3.StringUtils; import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.Collections; +import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * Maintain label for each Entity or Relation: Entity(label), Relation(Label, srcLabel, dstLabel). */ -public class GraphLabelType { - public static GraphLabelType DEFAULT = new GraphLabelType(); - - private String label; - private Integer labelId; - @Nullable private String srcLabel; - @Nullable private Integer srcLabelId; - @Nullable private String dstLabel; - @Nullable private Integer dstLabelId; - - public GraphLabelType() { - this.label = StringUtils.EMPTY; - this.labelId = -1; - } - - public GraphLabelType label(String label) { - Objects.requireNonNull(label); - this.label = label; - return this; - } - - public GraphLabelType labelId(int labelId) { - this.labelId = labelId; - return this; - } +public class GraphLabelType extends AbstractSqlType { + private final List labels; - public GraphLabelType srcLabel(String srcLabel) { - this.srcLabel = srcLabel; - return this; + public GraphLabelType(Entry label) { + this(ImmutableList.of(label)); } - public GraphLabelType srcLabelId(int srcLabelId) { - this.srcLabelId = srcLabelId; - return this; + public GraphLabelType(List labels) { + this(labels, SqlTypeName.CHAR); } - public GraphLabelType dstLabel(String dstLabel) { - this.dstLabel = dstLabel; - return this; + public GraphLabelType(Entry label, SqlTypeName typeName) { + this(ImmutableList.of(label), typeName); } - public GraphLabelType dstLabelId(int dstLabelId) { - this.dstLabelId = dstLabelId; - return this; + public GraphLabelType(List labels, SqlTypeName typeName) { + super(typeName, false, null); + this.labels = ObjectUtils.requireNonEmpty(labels); + this.computeDigest(); } - public String getLabel() { - return label; + public Entry getSingleLabelEntry() { + return labels.get(0); } - public Integer getLabelId() { - return labelId; + public List getLabelsEntry() { + return Collections.unmodifiableList(labels); } - public @Nullable String getSrcLabel() { - return srcLabel; + public List getLabelsString() { + return getLabelsEntry().stream() + .map(k -> k.toString()) + .collect(Collectors.toUnmodifiableList()); } - public @Nullable Integer getSrcLabelId() { - return srcLabelId; + @Override + protected void generateTypeString(StringBuilder stringBuilder, boolean b) { + stringBuilder.append(getLabelsString()); } - public @Nullable String getDstLabel() { - return dstLabel; + public void removeLabels(List labelsToRemove) { + this.labels.removeAll(labelsToRemove); } - public @Nullable Integer getDstLabelId() { - return dstLabelId; + public static class Entry { + private String label; + private Integer labelId; + @Nullable private String srcLabel; + @Nullable private Integer srcLabelId; + @Nullable private String dstLabel; + @Nullable private Integer dstLabelId; + + public Entry() { + this.label = StringUtils.EMPTY; + this.labelId = -1; + } + + public Entry label(String label) { + Objects.requireNonNull(label); + this.label = label; + return this; + } + + public Entry labelId(int labelId) { + this.labelId = labelId; + return this; + } + + public Entry srcLabel(String srcLabel) { + this.srcLabel = srcLabel; + return this; + } + + public Entry srcLabelId(int srcLabelId) { + this.srcLabelId = srcLabelId; + return this; + } + + public Entry dstLabel(String dstLabel) { + this.dstLabel = dstLabel; + return this; + } + + public Entry dstLabelId(int dstLabelId) { + this.dstLabelId = dstLabelId; + return this; + } + + public String getLabel() { + return label; + } + + public Integer getLabelId() { + return labelId; + } + + public @Nullable String getSrcLabel() { + return srcLabel; + } + + public @Nullable Integer getSrcLabelId() { + return srcLabelId; + } + + public @Nullable String getDstLabel() { + return dstLabel; + } + + public @Nullable Integer getDstLabelId() { + return dstLabelId; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + if (srcLabel == null || dstLabel == null) { + builder.append("VertexLabel("); + builder.append(label); + builder.append(")"); + } else { + builder.append("EdgeLabel("); + builder.append(label + ", " + srcLabel + ", " + dstLabel); + builder.append(")"); + } + return builder.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Entry entry = (Entry) o; + return Objects.equals(label, entry.label) + && Objects.equals(labelId, entry.labelId) + && Objects.equals(srcLabel, entry.srcLabel) + && Objects.equals(srcLabelId, entry.srcLabelId) + && Objects.equals(dstLabel, entry.dstLabel) + && Objects.equals(dstLabelId, entry.dstLabelId); + } + + @Override + public int hashCode() { + return Objects.hash(label, labelId, srcLabel, srcLabelId, dstLabel, dstLabelId); + } } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java index a4beb93869fa..acd4e33df0c2 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaType.java @@ -17,23 +17,27 @@ package com.alibaba.graphscope.common.ir.type; import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import org.apache.calcite.linq4j.Ord; -import org.apache.calcite.rel.type.RelDataTypeFamily; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.rel.type.RelRecordType; -import org.apache.calcite.rel.type.StructKind; +import org.apache.calcite.rel.type.*; +import org.apache.commons.lang3.ObjectUtils; +import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * Denote DataType of an entity or a relation, including opt, label and attributes */ public class GraphSchemaType extends RelRecordType { - protected GraphOpt.Source scanOpt; - protected GraphLabelType labelType; + private final GraphOpt.Source scanOpt; + private final GraphLabelType labelType; + private final List fuzzySchemaTypes; /** * @param scanOpt entity or relation @@ -45,11 +49,6 @@ public GraphSchemaType( this(scanOpt, labelType, fields, false); } - protected GraphSchemaType( - GraphOpt.Source scanOpt, List fields, boolean isNullable) { - this(scanOpt, GraphLabelType.DEFAULT, fields, isNullable); - } - /** * add a constructor to accept {@code isNullable}, a nullable GraphSchemaType will be created after left outer join * @param scanOpt @@ -62,11 +61,77 @@ public GraphSchemaType( GraphLabelType labelType, List fields, boolean isNullable) { + this(scanOpt, labelType, fields, ImmutableList.of(), isNullable); + } + + protected GraphSchemaType( + GraphOpt.Source scanOpt, + GraphLabelType labelType, + List fields, + List fuzzySchemaTypes, + boolean isNullable) { super(StructKind.NONE, fields, isNullable); this.scanOpt = scanOpt; + this.fuzzySchemaTypes = Objects.requireNonNull(fuzzySchemaTypes); this.labelType = labelType; } + public static GraphSchemaType create( + List list, RelDataTypeFactory typeFactory) { + return create(list, typeFactory, false); + } + + public static GraphSchemaType create( + List list, RelDataTypeFactory typeFactory, boolean isNullable) { + ObjectUtils.requireNonEmpty(list, "schema type list should not be empty"); + if (list.size() == 1) { + return list.get(0); + } + GraphOpt.Source scanOpt = list.get(0).getScanOpt(); + List labelOpts = Lists.newArrayList(); + List fields = Lists.newArrayList(); + List commonFields = Lists.newArrayList(list.get(0).getFieldList()); + List fuzzyEntries = Lists.newArrayList(); + for (GraphSchemaType type : list) { + Preconditions.checkArgument( + !type.fuzzy(), + "fuzzy label types nested in list of " + + GraphSchemaType.class + + " is considered to be invalid here"); + labelOpts.add( + "{label=" + + type.getLabelType().getLabelsString() + + ", opt=" + + type.scanOpt + + "}"); + if (type.getScanOpt() != scanOpt) { + throw new IllegalArgumentException( + "fuzzy label types should have the same opt, but is " + labelOpts); + } + fields.addAll(type.getFieldList()); + commonFields.retainAll(type.getFieldList()); + fuzzyEntries.addAll(type.getLabelType().getLabelsEntry()); + } + fields = + fields.stream() + .distinct() + .map( + k -> { + if (!commonFields.contains( + k)) { // can be optional for some labels + return new RelDataTypeFieldImpl( + k.getName(), + k.getIndex(), + typeFactory.createTypeWithNullability( + k.getType(), true)); + } + return k; + }) + .collect(Collectors.toList()); + return new GraphSchemaType( + scanOpt, new GraphLabelType(fuzzyEntries), fields, list, isNullable); + } + public GraphOpt.Source getScanOpt() { return scanOpt; } @@ -119,4 +184,14 @@ public boolean isStruct() { public RelDataTypeFamily getFamily() { return scanOpt; } + + public List getSchemaTypeAsList() { + return ObjectUtils.isEmpty(this.fuzzySchemaTypes) + ? ImmutableList.of(this) + : Collections.unmodifiableList(this.fuzzySchemaTypes); + } + + public boolean fuzzy() { + return this.labelType.getLabelsEntry().size() > 1 || this.fuzzySchemaTypes.size() > 1; + } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaTypeList.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaTypeList.java deleted file mode 100644 index 7c464dd27b86..000000000000 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphSchemaTypeList.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright 2020 Alibaba Group Holding Limited. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.alibaba.graphscope.common.ir.type; - -import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; - -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.commons.lang3.ObjectUtils; - -import java.util.*; -import java.util.stream.Collectors; - -/** - * A list of {@code IrSchemaType}, to denote fuzzy conditions in a vertex or an edge, i.e. g.V() or g.V().hasLabel("person", "software") - */ -public class GraphSchemaTypeList extends GraphSchemaType implements List { - private List schemaTypes; - - public static GraphSchemaTypeList create(List list) { - return create(list, false); - } - - public static GraphSchemaTypeList create(List list, boolean isNullable) { - ObjectUtils.requireNonEmpty(list); - GraphOpt.Source scanOpt = list.get(0).getScanOpt(); - List labelOpts = new ArrayList<>(); - List fields = new ArrayList<>(); - for (GraphSchemaType type : list) { - labelOpts.add("{label=" + type.labelType.getLabel() + ", opt=" + type.scanOpt + "}"); - if (type.getScanOpt() != scanOpt) { - throw new IllegalArgumentException( - "fuzzy label types should have the same opt, but is " + labelOpts); - } - fields.addAll(type.getFieldList()); - } - return new GraphSchemaTypeList( - scanOpt, list, fields.stream().distinct().collect(Collectors.toList()), isNullable); - } - - protected GraphSchemaTypeList( - GraphOpt.Source scanOpt, - List schemaTypes, - List fields, - boolean isNullable) { - super(scanOpt, fields, isNullable); - this.schemaTypes = schemaTypes; - } - - @Override - public int size() { - return this.schemaTypes.size(); - } - - @Override - public boolean isEmpty() { - return this.schemaTypes.isEmpty(); - } - - @Override - public boolean contains(Object o) { - return this.schemaTypes.contains(o); - } - - @Override - public Iterator iterator() { - return this.schemaTypes.iterator(); - } - - @Override - public Object[] toArray() { - return this.schemaTypes.toArray(); - } - - @Override - public T[] toArray(T[] a) { - return this.schemaTypes.toArray(a); - } - - @Override - public boolean add(GraphSchemaType graphSchemaType) { - return this.schemaTypes.add(graphSchemaType); - } - - @Override - public boolean remove(Object o) { - return this.schemaTypes.remove(o); - } - - @Override - public boolean containsAll(Collection c) { - return this.schemaTypes.containsAll(c); - } - - @Override - public boolean addAll(Collection c) { - return this.schemaTypes.addAll(c); - } - - @Override - public boolean addAll(int index, Collection c) { - return this.schemaTypes.addAll(index, c); - } - - @Override - public boolean removeAll(Collection c) { - return this.schemaTypes.removeAll(c); - } - - @Override - public boolean retainAll(Collection c) { - return this.schemaTypes.retainAll(c); - } - - @Override - public void clear() { - this.schemaTypes.clear(); - } - - @Override - public GraphSchemaType get(int index) { - return this.schemaTypes.get(index); - } - - @Override - public GraphSchemaType set(int index, GraphSchemaType element) { - return this.schemaTypes.set(index, element); - } - - @Override - public void add(int index, GraphSchemaType element) { - this.schemaTypes.add(index, element); - } - - @Override - public GraphSchemaType remove(int index) { - return this.schemaTypes.remove(index); - } - - @Override - public int indexOf(Object o) { - return this.schemaTypes.indexOf(o); - } - - @Override - public int lastIndexOf(Object o) { - return this.schemaTypes.lastIndexOf(o); - } - - @Override - public ListIterator listIterator() { - return this.schemaTypes.listIterator(); - } - - @Override - public ListIterator listIterator(int index) { - return this.schemaTypes.listIterator(index); - } - - @Override - public List subList(int fromIndex, int toIndex) { - return this.schemaTypes.subList(fromIndex, toIndex); - } -} diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFactoryImpl.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFactoryImpl.java index 4411b89ee833..b1b96fb0b9c4 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFactoryImpl.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/common/ir/type/GraphTypeFactoryImpl.java @@ -18,7 +18,6 @@ import com.alibaba.graphscope.common.config.Configs; import com.alibaba.graphscope.common.config.FrontendConfig; -import com.google.common.collect.Lists; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.rel.type.RelDataType; @@ -36,19 +35,24 @@ public GraphTypeFactoryImpl(Configs configs) { @Override public RelDataType createTypeWithNullability(RelDataType type, boolean nullable) { RelDataType newType; - if (type instanceof GraphSchemaTypeList) { - GraphSchemaTypeList schemaTypeList = (GraphSchemaTypeList) type; - newType = - GraphSchemaTypeList.create( - Lists.newArrayList(schemaTypeList.listIterator()), nullable); - } else if (type instanceof GraphSchemaType) { + if (type instanceof GraphSchemaType) { GraphSchemaType schemaType = (GraphSchemaType) type; - newType = - new GraphSchemaType( - schemaType.getScanOpt(), - schemaType.getLabelType(), - schemaType.getFieldList(), - nullable); + if (schemaType.getSchemaTypeAsList().size() > 1) { // fuzzy schema type + newType = + new GraphSchemaType( + schemaType.getScanOpt(), + schemaType.getLabelType(), + schemaType.getFieldList(), + schemaType.getSchemaTypeAsList(), + nullable); + } else { + newType = + new GraphSchemaType( + schemaType.getScanOpt(), + schemaType.getLabelType(), + schemaType.getFieldList(), + nullable); + } } else { newType = super.createTypeWithNullability(type, nullable); } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java index 21a7b1e8d899..fec064c5c98d 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/antlr4/visitor/ExpressionVisitor.java @@ -22,6 +22,7 @@ import com.alibaba.graphscope.common.ir.tools.GraphBuilder; import com.alibaba.graphscope.common.ir.tools.GraphRexBuilder; import com.alibaba.graphscope.common.ir.tools.GraphStdOperatorTable; +import com.alibaba.graphscope.common.ir.tools.config.GraphOpt; import com.alibaba.graphscope.common.ir.type.GraphProperty; import com.alibaba.graphscope.common.ir.type.GraphSchemaType; import com.alibaba.graphscope.cypher.antlr4.visitor.type.ExprVisitorResult; @@ -320,6 +321,24 @@ public ExprVisitorResult visitOC_SimpleFunction( List exprCtx = ctx.oC_Expression(); String functionName = ctx.oC_FunctionName().getText(); switch (functionName.toUpperCase()) { + case "LABELS": + RexNode labelVar = builder.variable(exprCtx.get(0).getText()); + Preconditions.checkArgument( + labelVar.getType() instanceof GraphSchemaType + && ((GraphSchemaType) labelVar.getType()).getScanOpt() + == GraphOpt.Source.VERTEX, + "'labels' can only be applied on vertex type"); + return new ExprVisitorResult( + builder.variable(exprCtx.get(0).getText(), GraphProperty.LABEL_KEY)); + case "TYPE": + RexNode typeVar = builder.variable(exprCtx.get(0).getText()); + Preconditions.checkArgument( + typeVar.getType() instanceof GraphSchemaType + && ((GraphSchemaType) typeVar.getType()).getScanOpt() + == GraphOpt.Source.EDGE, + "'type' can only be applied on edge type"); + return new ExprVisitorResult( + builder.variable(exprCtx.get(0).getText(), GraphProperty.LABEL_KEY)); case "LENGTH": Preconditions.checkArgument( !exprCtx.isEmpty(), "LENGTH function should have one argument"); @@ -361,6 +380,8 @@ public ExprVisitorResult visitOC_SimpleFunction( private FunctionType getFunctionType(String functionName) { switch (functionName.toUpperCase()) { + case "LABELS": + case "TYPE": case "LENGTH": case "HEAD": return FunctionType.SIMPLE; diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/QueryContext.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/QueryContext.java similarity index 94% rename from interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/QueryContext.java rename to interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/QueryContext.java index 2af48d535e48..0b14a36a77e6 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/QueryContext.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/QueryContext.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.alibaba.graphscope.cypher.integration.suite.ldbc; +package com.alibaba.graphscope.cypher.integration.suite; import java.util.Collections; import java.util.List; diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/LdbcQueries.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/LdbcQueries.java index 342d7d43cacd..6db3f417c98f 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/LdbcQueries.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/LdbcQueries.java @@ -16,6 +16,8 @@ package com.alibaba.graphscope.cypher.integration.suite.ldbc; +import com.alibaba.graphscope.cypher.integration.suite.QueryContext; + import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/SimpleMatchQueries.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/simple/SimpleMatchQueries.java similarity index 90% rename from interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/SimpleMatchQueries.java rename to interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/simple/SimpleMatchQueries.java index b5b89c89d98d..6470b3813ff1 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/ldbc/SimpleMatchQueries.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/integration/suite/simple/SimpleMatchQueries.java @@ -13,7 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.alibaba.graphscope.cypher.integration.suite.ldbc; +package com.alibaba.graphscope.cypher.integration.suite.simple; + +import com.alibaba.graphscope.cypher.integration.suite.QueryContext; import java.util.Arrays; import java.util.List; @@ -128,4 +130,13 @@ public static QueryContext get_simple_match_query_9_test() { "Record<{postId: 33042}>"); return new QueryContext(query, expected); } + + public static QueryContext get_simple_match_query_10_test() { + String query = + "MATCH( a {id:933})-[b]-(c {id: 2199023256077}) return labels(a) AS" + + " vertexLabelName, type(b) AS edgeLabelName;"; + List expected = + Arrays.asList("Record<{vertexLabelName: \"PERSON\", edgeLabelName: \"KNOWS\"}>"); + return new QueryContext(query, expected); + } } diff --git a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/result/CypherRecordParser.java b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/result/CypherRecordParser.java index 0e0389970aa3..7b394d2df623 100644 --- a/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/result/CypherRecordParser.java +++ b/interactive_engine/compiler/src/main/java/com/alibaba/graphscope/cypher/result/CypherRecordParser.java @@ -19,7 +19,6 @@ import com.alibaba.graphscope.common.ir.type.GraphLabelType; import com.alibaba.graphscope.common.ir.type.GraphPathType; import com.alibaba.graphscope.common.ir.type.GraphSchemaType; -import com.alibaba.graphscope.common.ir.type.GraphSchemaTypeList; import com.alibaba.graphscope.common.result.RecordParser; import com.alibaba.graphscope.gaia.proto.Common; import com.alibaba.graphscope.gaia.proto.IrResult; @@ -189,6 +188,9 @@ protected AnyValue parseGraphPath(IrResult.GraphPath path, @Nullable RelDataType } protected AnyValue parseValue(Common.Value value, @Nullable RelDataType dataType) { + if (dataType instanceof GraphLabelType) { + return Values.stringValue(parseLabelValue(value, (GraphLabelType) dataType)); + } switch (value.getItemCase()) { case BOOLEAN: return value.getBoolean() ? BooleanValue.TRUE : BooleanValue.FALSE; @@ -224,18 +226,20 @@ protected AnyValue parseValue(Common.Value value, @Nullable RelDataType dataType } } - private String getLabelName(Common.NameOrId nameOrId, List labelTypes) { + private String getLabelName(Common.NameOrId nameOrId, @Nullable GraphLabelType labelTypes) { switch (nameOrId.getItemCase()) { case NAME: return nameOrId.getName(); case ID: default: List labelIds = new ArrayList<>(); - for (GraphLabelType labelType : labelTypes) { - if (labelType.getLabelId() == nameOrId.getId()) { - return labelType.getLabel(); + if (labelTypes != null) { + for (GraphLabelType.Entry labelType : labelTypes.getLabelsEntry()) { + if (labelType.getLabelId() == nameOrId.getId()) { + return labelType.getLabel(); + } + labelIds.add(labelType.getLabelId()); } - labelIds.add(labelType.getLabelId()); } logger.warn( "label id={} not found, expected ids are {}", nameOrId.getId(), labelIds); @@ -243,16 +247,18 @@ private String getLabelName(Common.NameOrId nameOrId, List label } } - private String getSrcLabelName(Common.NameOrId nameOrId, List labelTypes) { + private String getSrcLabelName(Common.NameOrId nameOrId, @Nullable GraphLabelType labelTypes) { switch (nameOrId.getItemCase()) { case NAME: return nameOrId.getName(); case ID: default: List labelIds = new ArrayList<>(); - for (GraphLabelType labelType : labelTypes) { - if (labelType.getSrcLabelId() == nameOrId.getId()) { - return labelType.getSrcLabel(); + if (labelTypes != null) { + for (GraphLabelType.Entry labelType : labelTypes.getLabelsEntry()) { + if (labelType.getSrcLabelId() == nameOrId.getId()) { + return labelType.getSrcLabel(); + } } } logger.warn( @@ -263,16 +269,18 @@ private String getSrcLabelName(Common.NameOrId nameOrId, List la } } - private String getDstLabelName(Common.NameOrId nameOrId, List labelTypes) { + private String getDstLabelName(Common.NameOrId nameOrId, @Nullable GraphLabelType labelTypes) { switch (nameOrId.getItemCase()) { case NAME: return nameOrId.getName(); case ID: default: List labelIds = new ArrayList<>(); - for (GraphLabelType labelType : labelTypes) { - if (labelType.getDstLabelId() == nameOrId.getId()) { - return labelType.getDstLabel(); + if (labelTypes != null) { + for (GraphLabelType.Entry labelType : labelTypes.getLabelsEntry()) { + if (labelType.getDstLabelId() == nameOrId.getId()) { + return labelType.getDstLabel(); + } } } logger.warn( @@ -283,18 +291,12 @@ private String getDstLabelName(Common.NameOrId nameOrId, List la } } - private List getLabelTypes(RelDataType dataType) { - List labelTypes = Lists.newArrayList(); - if (dataType instanceof GraphSchemaTypeList) { - ((GraphSchemaTypeList) dataType) - .forEach( - k -> { - labelTypes.add(k.getLabelType()); - }); - } else if (dataType instanceof GraphSchemaType) { - labelTypes.add(((GraphSchemaType) dataType).getLabelType()); + private @Nullable GraphLabelType getLabelTypes(RelDataType dataType) { + if (dataType instanceof GraphSchemaType) { + return ((GraphSchemaType) dataType).getLabelType(); + } else { + return null; } - return labelTypes; } private RelDataType getVertexType(RelDataType graphPathType) { @@ -308,4 +310,35 @@ private RelDataType getEdgeType(RelDataType graphPathType) { ? ((GraphPathType) graphPathType).getComponentType().getExpandType() : graphPathType; } + + private String parseLabelValue(Common.Value value, GraphLabelType type) { + switch (value.getItemCase()) { + case STR: + return value.getStr(); + case I32: + return parseLabelValue(value.getI32(), type); + case I64: + return parseLabelValue(value.getI64(), type); + default: + throw new IllegalArgumentException( + "cannot parse label value with type=" + value.getItemCase().name()); + } + } + + private String parseLabelValue(long labelId, GraphLabelType type) { + List expectedLabelIds = Lists.newArrayList(); + for (GraphLabelType.Entry entry : type.getLabelsEntry()) { + if (entry.getLabelId() == labelId) { + return entry.getLabel(); + } + expectedLabelIds.add(entry.getLabelId()); + } + throw new IllegalArgumentException( + "cannot parse label value=" + + labelId + + " from expected type=" + + type + + ", expected ids are " + + expectedLabelIds); + } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java index 48f4324edd62..bc6a3263dd06 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/antlr4/MatchTest.java @@ -275,4 +275,20 @@ public void match_14_test() { SqlTypeName.CHAR, node.getRowType().getFieldList().get(0).getType().getSqlTypeName()); } + + @Test + public void match_15_test() { + RelNode node = Utils.eval("Match (a)-[b]-(c) Return labels(a), type(b)").build(); + Assert.assertEquals( + "GraphLogicalProject(~label=[a.~label], ~label0=[b.~label], isAppend=[false])\n" + + " GraphLogicalSingleMatch(input=[null]," + + " sentence=[GraphLogicalGetV(tableConfig=[{isAll=true, tables=[software," + + " person]}], alias=[c], opt=[OTHER])\n" + + " GraphLogicalExpand(tableConfig=[{isAll=true, tables=[created, knows]}]," + + " alias=[b], opt=[BOTH])\n" + + " GraphLogicalSource(tableConfig=[{isAll=true, tables=[software," + + " person]}], alias=[a], opt=[VERTEX])\n" + + "], matchOpt=[INNER])", + node.explain().trim()); + } } diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/IrLdbcTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/IrLdbcTest.java index c13694b8815b..5e4ca8c58bb0 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/IrLdbcTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/IrLdbcTest.java @@ -18,8 +18,8 @@ import static org.junit.Assume.assumeTrue; +import com.alibaba.graphscope.cypher.integration.suite.QueryContext; import com.alibaba.graphscope.cypher.integration.suite.ldbc.LdbcQueries; -import com.alibaba.graphscope.cypher.integration.suite.ldbc.QueryContext; import org.junit.AfterClass; import org.junit.Assert; diff --git a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/SimpleMatchTest.java b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/SimpleMatchTest.java index 693f313ea82b..82cac878ac18 100644 --- a/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/SimpleMatchTest.java +++ b/interactive_engine/compiler/src/test/java/com/alibaba/graphscope/cypher/integration/ldbc/SimpleMatchTest.java @@ -15,8 +15,8 @@ */ package com.alibaba.graphscope.cypher.integration.ldbc; -import com.alibaba.graphscope.cypher.integration.suite.ldbc.QueryContext; -import com.alibaba.graphscope.cypher.integration.suite.ldbc.SimpleMatchQueries; +import com.alibaba.graphscope.cypher.integration.suite.QueryContext; +import com.alibaba.graphscope.cypher.integration.suite.simple.SimpleMatchQueries; import org.junit.AfterClass; import org.junit.Assert; @@ -100,6 +100,13 @@ public void run_simple_match_9_test() { Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); } + @Test + public void run_simple_match_10_test() { + QueryContext testQuery = SimpleMatchQueries.get_simple_match_query_10_test(); + Result result = session.run(testQuery.getQuery()); + Assert.assertEquals(testQuery.getExpectedResult().toString(), result.list().toString()); + } + @AfterClass public static void afterClass() { if (session != null) {