Skip to content

Commit

Permalink
Added weights to storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Devesh Sarda committed Dec 27, 2023
1 parent ee171bb commit c04d75a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/cpp/include/storage/graph_storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ struct GraphModelStoragePtrs {
shared_ptr<Storage> edges = nullptr;
shared_ptr<Storage> train_edges = nullptr;
shared_ptr<Storage> train_edges_dst_sort = nullptr;
shared_ptr<Storage> train_edges_weights = nullptr;
shared_ptr<Storage> validation_edges = nullptr;
shared_ptr<Storage> validation_edges_weights = nullptr;
shared_ptr<Storage> test_edges = nullptr;
shared_ptr<Storage> test_edges_weights = nullptr;
shared_ptr<Storage> nodes = nullptr;
shared_ptr<Storage> train_nodes = nullptr;
shared_ptr<Storage> valid_nodes = nullptr;
Expand Down
5 changes: 4 additions & 1 deletion src/cpp/src/storage/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ shared_ptr<GraphModelStorage> initializeStorageLinkPrediction(shared_ptr<Model>
storage_ptrs.train_edges_dst_sort = edge_storages["train_edges_dst_sort"];
storage_ptrs.validation_edges = edge_storages["validation_edges"];
storage_ptrs.test_edges = edge_storages["test_edges"];
storage_ptrs.train_edges_weights = edge_storages["train_edge_weights_storage"];
storage_ptrs.validation_edges_weights = edge_storages["valid_edge_weights_storage"];
storage_ptrs.test_edges_weights = edge_storages["test_edge_weights_storage"];

storage_ptrs.node_features = initializeNodeFeatures(model, storage_config);
storage_ptrs.node_embeddings = std::get<0>(node_embeddings);
Expand All @@ -506,10 +509,10 @@ shared_ptr<GraphModelStorage> initializeStorageNodeClassification(shared_ptr<Mod
shared_ptr<Storage> node_labels = initializeNodeLabels(model, storage_config);

GraphModelStoragePtrs storage_ptrs = {};

storage_ptrs.train_edges = edge_storages["train_edges"];
storage_ptrs.train_edges_dst_sort = edge_storages["train_edges_dst_sort"];
storage_ptrs.edges = storage_ptrs.train_edges;
storage_ptrs.train_edges_weights = edge_storages["train_edge_weights_storage"];

storage_ptrs.train_nodes = std::get<0>(node_id_storages);
storage_ptrs.valid_nodes = std::get<1>(node_id_storages);
Expand Down

0 comments on commit c04d75a

Please sign in to comment.