Skip to content

Commit

Permalink
Random feature layout in simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
Devesh Sarda committed Feb 12, 2024
1 parent 90807f0 commit 61d509a
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,5 @@ Thumbs.db

src/cpp/third_party
test_datasets
simulator/datasets
simulator/datasets
simulator/images
1 change: 1 addition & 0 deletions simulator/configs/arvix.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"dataset_name" : "ogbn_arxiv",
"features_stats" : {
"feature_layout" : "random",
"page_size" : "16 KB",
"feature_dimension" : 128,
"feature_size" : "float32"
Expand Down
Binary file removed simulator/images/arvix.png
Binary file not shown.
11 changes: 5 additions & 6 deletions simulator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ def read_config_file(config_file):

def read_arguments():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("config_file", type=str, help="The config file containing the details for the simulation")
parser.add_argument("--config_file", type=str, help="The config file containing the details for the simulation")
parser.add_argument("--save_path", required=True, type=str, help="The path to save the resulting image to")
parser.add_argument("--graph_title", required=True, type=str, help="The title of the saved graph")
return parser.parse_args()


IMAGES_SAVE_DIR = "images"


def main():
arguments = read_arguments()
config = read_config_file(arguments.config_file)
Expand All @@ -44,8 +43,8 @@ def main():
print("Got result for", len(pages_loaded), "nodes out of", len(nodes_to_sample), "nodes")

# Save the histogram
save_path = os.path.join(IMAGES_SAVE_DIR, os.path.basename(arguments.config_file).replace("json", "png"))
visualize_results(pages_loaded, save_path, config["dataset_name"])
os.makedirs(os.path.dirname(arguments.save_path), exist_ok=True)
visualize_results(pages_loaded, arguments.save_path, arguments.graph_title, config["dataset_name"])


if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion simulator/src/features_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import humanfriendly
import os
import math
import random


class FeaturesLoader:
Expand All @@ -11,9 +12,16 @@ def __init__(self, data_loader, features_stat):
self.node_feature_size = self.feature_size * features_stat["feature_dimension"]

self.nodes_per_page = max(int(self.page_size / self.node_feature_size), 1)
self.total_pages = int(math.ceil(data_loader.get_num_nodes() / (1.0 * self.nodes_per_page)))
total_nodes = data_loader.get_num_nodes()
self.total_pages = int(math.ceil(total_nodes / (1.0 * self.nodes_per_page)))

self.node_location_map = [i for i in range(total_nodes)]
if "feature_layout" in features_stat and features_stat["feature_layout"] == "random":
random.shuffle(self.node_location_map)
print(self.node_location_map[:10])

def get_node_page(self, node_id):
node_location = self.node_location_map[node_id]
return int(node_id / self.nodes_per_page)

def get_total_file_size(self):
Expand Down
4 changes: 2 additions & 2 deletions simulator/src/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import os


def visualize_results(pages_loaded, save_path, dataset_name, num_bins=50):
def visualize_results(pages_loaded, save_path, graph_title, dataset_name, num_bins=50):
# Create the histogram
plt.figure()
plt.ecdf(pages_loaded, label="CDF")
plt.hist(pages_loaded, bins=num_bins, histtype="step", density=True, cumulative=True, label="Cumulative histogram")
plt.xlabel("Number of pages loaded for node inference")
plt.ylabel("Percentage of nodes")
plt.title("Number of pages loaded for node inference on " + dataset_name)
plt.title(graph_title)
plt.xlim(0, 50)
plt.legend()

Expand Down

0 comments on commit 61d509a

Please sign in to comment.