Skip to content

Commit

Permalink
Weighted storage sort
Browse files Browse the repository at this point in the history
  • Loading branch information
Devesh Sarda committed Dec 14, 2023
1 parent 2f27ffe commit a8a3798
Show file tree
Hide file tree
Showing 8 changed files with 263 additions and 54 deletions.
1 change: 1 addition & 0 deletions src/cpp/include/configuration/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const string dst_sort = "_dst_sort";

// String constants related to edge storage.
// NOTE(review): how these are combined into full paths is not visible here — confirm against the storage initialization code.
const string edges_directory = "edges/";
const string edges_file = "edges";
// Presumably a suffix used to name the edge-weights file that accompanies an edge file — TODO confirm against callers.
const string weights = "_weights";
const string edge_partition_offsets_file = "partition_offsets.txt";

const string node_mapping_file = "node_mapping.txt";
Expand Down
8 changes: 4 additions & 4 deletions src/cpp/include/storage/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>

#include <unistd.h>
#include <string>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <map>

#include "common/datatypes.h"
#include "storage/graph_storage.h"
#include "storage/storage.h"

std::tuple<shared_ptr<Storage>, shared_ptr<Storage>, shared_ptr<Storage>, shared_ptr<Storage>> initializeEdges(shared_ptr<StorageConfig> storage_config,
LearningTask learning_task);
std::map<std::string, shared_ptr<Storage>> initializeEdges(shared_ptr<StorageConfig> storage_config, LearningTask learning_task);

std::tuple<shared_ptr<Storage>, shared_ptr<Storage>> initializeNodeEmbeddings(std::shared_ptr<Model> model, shared_ptr<StorageConfig> storage_config,
bool reinitialize, bool train, std::shared_ptr<InitConfig> init_config);
Expand Down
24 changes: 12 additions & 12 deletions src/cpp/include/storage/storage.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Storage {
torch::Tensor data_;
torch::Device device_;
string filename_;
bool loaded_;

Storage();

Expand All @@ -55,6 +56,8 @@ class Storage {

virtual void indexPut(Indices indices, torch::Tensor values) = 0;

virtual void rangePut(int64_t offset, torch::Tensor values) = 0;

virtual void rangePut(int64_t offset, int64_t n, torch::Tensor values) = 0;

virtual void load() = 0;
Expand All @@ -65,7 +68,7 @@ class Storage {

virtual void shuffle() = 0;

virtual void sort(bool src) = 0;
virtual void sort(bool src, std::shared_ptr<Storage> weight_file = nullptr) = 0;

int64_t getDim0() { return dim0_size_; }

Expand All @@ -88,10 +91,8 @@ class Storage {
/** Storage which uses the partition buffer, used for node embeddings and optimizer state */
class PartitionBufferStorage : public Storage {
public:
bool loaded_;

PartitionBuffer *buffer_;

shared_ptr<PartitionBufferOptions> options_;

PartitionBufferStorage(string filename, int64_t dim0_size, int64_t dim1_size, shared_ptr<PartitionBufferOptions> options);
Expand All @@ -102,7 +103,7 @@ class PartitionBufferStorage : public Storage {

~PartitionBufferStorage();

void rangePut(int64_t offset, torch::Tensor values);
void rangePut(int64_t offset, torch::Tensor values) override;

void append(torch::Tensor values);

Expand All @@ -124,7 +125,7 @@ class PartitionBufferStorage : public Storage {

void shuffle() override;

void sort(bool src) override;
void sort(bool src, std::shared_ptr<Storage> weight_file = nullptr) override;

Indices getRandomIds(int64_t size) { return buffer_->getRandomIds(size); }

Expand All @@ -150,8 +151,6 @@ class FlatFile : public Storage {
private:
int fd_;

bool loaded_;

public:
FlatFile(string filename, int64_t dim0_size, int64_t dim1_size, torch::Dtype dtype, bool alloc = false);

Expand All @@ -161,7 +160,7 @@ class FlatFile : public Storage {

~FlatFile(){};

void rangePut(int64_t offset, torch::Tensor values);
void rangePut(int64_t offset, torch::Tensor values) override;

void append(torch::Tensor values);

Expand All @@ -183,7 +182,7 @@ class FlatFile : public Storage {

void shuffle() override;

void sort(bool src) override;
void sort(bool src, std::shared_ptr<Storage> weight_file = nullptr) override;

void move(string new_filename);

Expand All @@ -199,8 +198,6 @@ class InMemory : public Storage {
private:
int fd_;

bool loaded_;

public:
InMemory(string filename, int64_t dim0_size, int64_t dim1_size, torch::Dtype dtype, torch::Device device);

Expand All @@ -226,11 +223,14 @@ class InMemory : public Storage {

void indexPut(Indices indices, torch::Tensor values) override;

void rangePut(int64_t offset, torch::Tensor values) override;

void rangePut(int64_t offset, int64_t n, torch::Tensor values) override;

void shuffle() override;

void sort(bool src) override;
void sort(bool src, std::shared_ptr<Storage> weight_file = nullptr) override;

};

#endif // MARIUS_STORAGE_H
9 changes: 6 additions & 3 deletions src/cpp/src/marius.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ void marius_train(shared_ptr<MariusConfig> marius_config) {
auto model = std::get<0>(tup);
auto graph_model_storage = std::get<1>(tup);
auto dataloader = std::get<2>(tup);


/*
shared_ptr<Trainer> trainer;
shared_ptr<Evaluator> evaluator;
Expand Down Expand Up @@ -159,7 +160,7 @@ void marius_train(shared_ptr<MariusConfig> marius_config) {
if (marius_config->storage->export_encoded_nodes) {
encode_and_export(dataloader, model, marius_config);
}
}
} */
}

void marius_eval(shared_ptr<MariusConfig> marius_config) {
Expand All @@ -170,6 +171,7 @@ void marius_eval(shared_ptr<MariusConfig> marius_config) {

shared_ptr<Evaluator> evaluator;

/*
if (marius_config->evaluation->epochs_per_eval > 0) {
if (marius_config->evaluation->pipeline->sync) {
evaluator = std::make_shared<SynchronousEvaluator>(dataloader, model);
Expand All @@ -182,6 +184,7 @@ void marius_eval(shared_ptr<MariusConfig> marius_config) {
if (marius_config->storage->export_encoded_nodes) {
encode_and_export(dataloader, model, marius_config);
}
*/
}

void marius(int argc, char *argv[]) {
Expand All @@ -201,7 +204,7 @@ void marius(int argc, char *argv[]) {
marius_train(marius_config);
} else {
marius_eval(marius_config);
}
}
}

// Program entry point: hands the raw command-line arguments to marius().
int main(int argc, char *argv[]) {
    marius(argc, argv);
    return 0;
}
Loading

0 comments on commit a8a3798

Please sign in to comment.