diff --git a/flex/engines/graph_db/database/graph_db.cc b/flex/engines/graph_db/database/graph_db.cc index 1b8fff41a8f6..55601315c43b 100644 --- a/flex/engines/graph_db/database/graph_db.cc +++ b/flex/engines/graph_db/database/graph_db.cc @@ -305,7 +305,12 @@ const Schema& GraphDB::schema() const { return graph_.schema(); } std::shared_ptr GraphDB::get_vertex_property_column( uint8_t label, const std::string& col_name) const { - return graph_.get_vertex_table(label).get_column(col_name); + return graph_.get_vertex_property_column(label, col_name); +} + +std::shared_ptr GraphDB::get_vertex_id_column( + uint8_t label) const { + return graph_.get_vertex_id_column(label); } AppWrapper GraphDB::CreateApp(uint8_t app_type, int thread_id) { diff --git a/flex/engines/graph_db/database/graph_db.h b/flex/engines/graph_db/database/graph_db.h index d345838f7be3..da24423b16d9 100644 --- a/flex/engines/graph_db/database/graph_db.h +++ b/flex/engines/graph_db/database/graph_db.h @@ -137,6 +137,8 @@ class GraphDB { std::shared_ptr get_vertex_property_column( uint8_t label, const std::string& col_name) const; + std::shared_ptr get_vertex_id_column(uint8_t label) const; + AppWrapper CreateApp(uint8_t app_type, int thread_id); void GetAppInfo(Encoder& result); diff --git a/flex/engines/graph_db/database/graph_db_session.cc b/flex/engines/graph_db/database/graph_db_session.cc index 8173fa65f5f8..ed128bf32605 100644 --- a/flex/engines/graph_db/database/graph_db_session.cc +++ b/flex/engines/graph_db/database/graph_db_session.cc @@ -79,33 +79,7 @@ std::shared_ptr GraphDBSession::get_vertex_property_column( std::shared_ptr GraphDBSession::get_vertex_id_column( uint8_t label) const { - if (db_.graph().lf_indexers_[label].get_type() == PropertyType::kInt64) { - return std::make_shared>( - dynamic_cast&>( - db_.graph().lf_indexers_[label].get_keys())); - } else if (db_.graph().lf_indexers_[label].get_type() == - PropertyType::kInt32) { - return std::make_shared>( - dynamic_cast&>( - db_.graph().lf_indexers_[label].get_keys())); - } else if (db_.graph().lf_indexers_[label].get_type() == - PropertyType::kUInt64) { - return std::make_shared>( - dynamic_cast&>( - db_.graph().lf_indexers_[label].get_keys())); - } else if (db_.graph().lf_indexers_[label].get_type() == - PropertyType::kUInt32) { - return std::make_shared>( - dynamic_cast&>( - db_.graph().lf_indexers_[label].get_keys())); - } else if (db_.graph().lf_indexers_[label].get_type() == - PropertyType::kStringView) { - return std::make_shared>( - dynamic_cast&>( - db_.graph().lf_indexers_[label].get_keys())); - } else { - return nullptr; - } + return db_.get_vertex_id_column(label); } Result> GraphDBSession::Eval(const std::string& input) { diff --git a/flex/engines/graph_db/database/read_transaction.h b/flex/engines/graph_db/database/read_transaction.h index 23b93acf3fe3..ef352a6ca903 100644 --- a/flex/engines/graph_db/database/read_transaction.h +++ b/flex/engines/graph_db/database/read_transaction.h @@ -290,11 +290,41 @@ class ReadTransaction { const MutablePropertyFragment& graph() const; + /* + * @brief Get the handle of the vertex property column, only for non-primary + * key columns. + */ const std::shared_ptr get_vertex_property_column( uint8_t label, const std::string& col_name) const { return graph_.get_vertex_table(label).get_column(col_name); } + /** + * @brief Get the handle of the vertex property column, including the primary + * key. + * @tparam T The type of the column. + * @param label The label of the vertex. + * @param col_name The name of the column. + */ + template + const std::shared_ptr> get_vertex_ref_property_column( + uint8_t label, const std::string& col_name) const { + auto pk = graph_.schema().get_vertex_primary_key(label); + CHECK(pk.size() == 1) << "Only support single primary key"; + if (col_name == std::get<1>(pk[0])) { + return std::dynamic_pointer_cast>( + graph_.get_vertex_id_column(label)); + } else { + auto ptr = graph_.get_vertex_table(label).get_column(col_name); + if (ptr) { + return std::dynamic_pointer_cast>( + CreateRefColumn(ptr)); + } else { + return nullptr; + } + } + } + class vertex_iterator { public: vertex_iterator(label_t label, vid_t cur, vid_t num, diff --git a/flex/engines/graph_db/runtime/adhoc/var.cc b/flex/engines/graph_db/runtime/adhoc/var.cc index 3f581aa80d4f..b8cc6e7161bc 100644 --- a/flex/engines/graph_db/runtime/adhoc/var.cc +++ b/flex/engines/graph_db/runtime/adhoc/var.cc @@ -56,25 +56,9 @@ Var::Var(const ReadTransaction& txn, const Context& ctx, if (pt.has_id()) { getter_ = std::make_shared(ctx, tag); } else if (pt.has_key()) { - if (pt.key().name() == "id") { - if (type_ == RTAnyType::kStringValue) { - getter_ = - std::make_shared>( - txn, ctx, tag); - } else if (type_ == RTAnyType::kI32Value) { - getter_ = std::make_shared>( - txn, ctx, tag); - } else if (type_ == RTAnyType::kI64Value) { - getter_ = std::make_shared>( - txn, ctx, tag); - } else { - LOG(FATAL) << "not support for " - << static_cast(type_.type_enum_); - } - } else { - getter_ = create_vertex_property_path_accessor(txn, ctx, tag, type_, - pt.key().name()); - } + getter_ = create_vertex_property_path_accessor(txn, ctx, tag, type_, + pt.key().name()); + } else if (pt.has_label()) { getter_ = create_vertex_label_path_accessor(ctx, tag); } else { @@ -126,23 +110,8 @@ Var::Var(const ReadTransaction& txn, const Context& ctx, if (pt.has_id()) { getter_ = std::make_shared(); } else if (pt.has_key()) { - if (pt.key().name() == "id") { - if (type_ == RTAnyType::kStringValue) { - getter_ = - std::make_shared>( - txn); - } else if (type_ == RTAnyType::kI32Value) { - getter_ = std::make_shared>(txn); - } else if (type_ == RTAnyType::kI64Value) { - getter_ = std::make_shared>(txn); - } else { - LOG(FATAL) << "not support for " - << static_cast(type_.type_enum_); - } - } else { - getter_ = create_vertex_property_vertex_accessor(txn, type_, - pt.key().name()); - } + getter_ = create_vertex_property_vertex_accessor(txn, type_, + pt.key().name()); } else if (pt.has_label()) { getter_ = std::make_shared(); } else { diff --git a/flex/engines/graph_db/runtime/common/accessors.h b/flex/engines/graph_db/runtime/common/accessors.h index 33a468a7155d..67f5e94534f4 100644 --- a/flex/engines/graph_db/runtime/common/accessors.h +++ b/flex/engines/graph_db/runtime/common/accessors.h @@ -156,9 +156,8 @@ class VertexPropertyPathAccessor : public IAccessor { int label_num = txn.schema().vertex_label_num(); property_columns_.resize(label_num, nullptr); for (int i = 0; i < label_num; ++i) { - property_columns_[i] = dynamic_cast*>( - txn.get_vertex_property_column(static_cast(i), prop_name) - .get()); + property_columns_[i] = txn.template get_vertex_ref_property_column( + static_cast(i), prop_name); } } @@ -205,7 +204,7 @@ class VertexPropertyPathAccessor : public IAccessor { private: const IVertexColumn& vertex_col_; - std::vector*> property_columns_; + std::vector>> property_columns_; }; class VertexLabelPathAccessor : public IAccessor { @@ -323,9 +322,8 @@ class VertexPropertyVertexAccessor : public IAccessor { int label_num = txn.schema().vertex_label_num(); property_columns_.resize(label_num, nullptr); for (int i = 0; i < label_num; ++i) { - property_columns_[i] = dynamic_cast*>( - txn.get_vertex_property_column(static_cast(i), prop_name) - .get()); + property_columns_[i] = txn.template get_vertex_ref_property_column( + static_cast(i), prop_name); } } @@ -366,7 +364,7 @@ class VertexPropertyVertexAccessor : public IAccessor { } private: - std::vector*> property_columns_; + std::vector>> property_columns_; }; class EdgeIdPathAccessor : public IAccessor { diff --git a/flex/engines/graph_db/runtime/common/columns/vertex_columns.h b/flex/engines/graph_db/runtime/common/columns/vertex_columns.h index 108ac48d5e5e..b984e3bd9638 100644 --- a/flex/engines/graph_db/runtime/common/columns/vertex_columns.h +++ b/flex/engines/graph_db/runtime/common/columns/vertex_columns.h @@ -215,6 +215,8 @@ class OptionalSLVertexColumn : public IVertexColumn { ISigColumn* generate_signature() const override; + label_t label() const { return label_; } + private: friend class OptionalSLVertexColumnBuilder; label_t label_; diff --git a/flex/interactive/sdk/python/gs_interactive/tests/conftest.py b/flex/interactive/sdk/python/gs_interactive/tests/conftest.py index 6282fc4a393b..d63048a0f4f3 100644 --- a/flex/interactive/sdk/python/gs_interactive/tests/conftest.py +++ b/flex/interactive/sdk/python/gs_interactive/tests/conftest.py @@ -334,6 +334,20 @@ def create_partial_modern_graph(interactive_session): delete_running_graph(interactive_session, graph_id) +@pytest.fixture(scope="function") +def create_graph_with_custom_pk_name(interactive_session): + modern_graph_custom_pk_name = modern_graph_full.copy() + for vertex_type in modern_graph_custom_pk_name["schema"]["vertex_types"]: + vertex_type["properties"][0]["property_name"] = "custom_id" + vertex_type["primary_keys"] = ["custom_id"] + create_graph_request = CreateGraphRequest.from_dict(modern_graph_custom_pk_name) + resp = interactive_session.create_graph(create_graph_request) + assert resp.is_ok() + graph_id = resp.get_value().graph_id + yield graph_id + delete_running_graph(interactive_session, graph_id) + + def wait_job_finish(sess: Session, job_id: str): assert job_id is not None while True: diff --git a/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py b/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py index eba1763b8975..af8c55a4b654 100644 --- a/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py +++ b/flex/interactive/sdk/python/gs_interactive/tests/test_robustness.py @@ -265,3 +265,27 @@ def test_call_proc_in_cypher(interactive_session, neo4j_session, create_modern_g for record in result: cnt += 1 assert cnt == 8 + + +def test_custom_pk_name( + interactive_session, neo4j_session, create_graph_with_custom_pk_name +): + print("[Test custom pk name]") + import_data_to_full_modern_graph( + interactive_session, create_graph_with_custom_pk_name + ) + start_service_on_graph(interactive_session, create_graph_with_custom_pk_name) + result = neo4j_session.run( + "MATCH (n: person) where n.custom_id = 4 return n.custom_id;" + ) + records = result.fetch(10) + for record in records: + print(record) + assert len(records) == 1 + + result = neo4j_session.run( + "MATCH (n:person)-[e]-(v:person) where v.custom_id = 1 return count(e);" + ) + records = result.fetch(1) + assert len(records) == 1 and records[0]["$f0"] == 2 + start_service_on_graph(interactive_session, "1") diff --git a/flex/storages/rt_mutable_graph/mutable_property_fragment.cc b/flex/storages/rt_mutable_graph/mutable_property_fragment.cc index 1cb2c329efd5..22f736fff8a7 100644 --- a/flex/storages/rt_mutable_graph/mutable_property_fragment.cc +++ b/flex/storages/rt_mutable_graph/mutable_property_fragment.cc @@ -496,4 +496,38 @@ const CsrBase* MutablePropertyFragment::get_ie_csr(label_t label, return ie_[index]; } +std::shared_ptr MutablePropertyFragment::get_vertex_property_column( + uint8_t label, const std::string& prop) const { + return vertex_data_[label].get_column(prop); +} + +std::shared_ptr MutablePropertyFragment::get_vertex_id_column( + uint8_t label) const { + if (lf_indexers_[label].get_type() == PropertyType::kInt64) { + return std::make_shared>( + dynamic_cast&>( + lf_indexers_[label].get_keys())); + } else if (lf_indexers_[label].get_type() == PropertyType::kInt32) { + return std::make_shared>( + dynamic_cast&>( + lf_indexers_[label].get_keys())); + } else if (lf_indexers_[label].get_type() == PropertyType::kUInt64) { + return std::make_shared>( + dynamic_cast&>( + lf_indexers_[label].get_keys())); + } else if (lf_indexers_[label].get_type() == PropertyType::kUInt32) { + return std::make_shared>( + dynamic_cast&>( + lf_indexers_[label].get_keys())); + } else if (lf_indexers_[label].get_type() == PropertyType::kStringView) { + return std::make_shared>( + dynamic_cast&>( + lf_indexers_[label].get_keys())); + } else { + LOG(ERROR) << "Unsupported vertex id type: " + << lf_indexers_[label].get_type(); + return nullptr; + } +} + } // namespace gs diff --git a/flex/storages/rt_mutable_graph/mutable_property_fragment.h b/flex/storages/rt_mutable_graph/mutable_property_fragment.h index d8bccbe55c85..39fdc9a9f285 100644 --- a/flex/storages/rt_mutable_graph/mutable_property_fragment.h +++ b/flex/storages/rt_mutable_graph/mutable_property_fragment.h @@ -112,6 +112,11 @@ class MutablePropertyFragment { void loadSchema(const std::string& filename); + std::shared_ptr get_vertex_property_column( + uint8_t label, const std::string& prop) const; + + std::shared_ptr get_vertex_id_column(uint8_t label) const; + Schema schema_; std::vector lf_indexers_; std::vector ie_, oe_;