Skip to content

Commit

Permalink
feat(gnn): create classification reports
Browse files Browse the repository at this point in the history
  • Loading branch information
DerYeger committed Oct 12, 2024
1 parent 5d4e875 commit 38c19f4
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 23 deletions.
23 changes: 18 additions & 5 deletions ml/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ dependencies:
- blas=1.0=openblas
- brotli-python=1.0.9=py310hc377ac9_7
- bzip2=1.0.8=h80987f9_5
- ca-certificates=2024.3.11=hca03da5_0
- certifi=2024.2.2=py310hca03da5_0
- ca-certificates=2024.9.24=hca03da5_0
- certifi=2024.8.30=py310hca03da5_0
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- ffmpeg=4.2.2=h04105a8_0
- freetype=2.12.1=h1192e45_0
Expand Down Expand Up @@ -47,13 +47,13 @@ dependencies:
- numpy-base=1.26.4=py310ha9811e2_0
- openh264=1.8.0=h98b2900_0
- openjpeg=2.3.0=h7a6adac_2
- openssl=3.0.13=h1a28f6b_1
- openssl=3.0.15=h80987f9_0
- pillow=10.2.0=py310h80987f9_0
- pip=23.3.1=py310hca03da5_0
- pybind11-abi=4=hd3eb1b0_1
- pygments=2.15.1=py310hca03da5_1
- pysocks=1.7.1=py310hca03da5_0
- python=3.10.14=hb885b13_0
- pytorch=1.13.1=py3.10_0
- readline=8.2=h1a28f6b_0
- requests=2.31.0=py310hca03da5_1
- rich=13.3.5=py310hca03da5_0
Expand All @@ -63,7 +63,6 @@ dependencies:
- torchaudio=0.13.1=py310_cpu
- torchvision=0.14.1=py310_cpu
- typing_extensions=4.9.0=py310hca03da5_1
- tzdata=2024a=h04d1e81_0
- urllib3=2.1.0=py310hca03da5_1
- wheel=0.41.2=py310hca03da5_0
- x264=1!152.20180806=h1a28f6b_0
Expand All @@ -75,17 +74,31 @@ dependencies:
- aiosignal==1.3.1
- async-timeout==4.0.3
- attrs==23.2.0
- dgl==2.2.0
- filelock==3.14.0
- frozenlist==1.4.1
- fsspec==2024.3.1
- jinja2==3.1.3
- joblib==1.4.0
- markupsafe==2.1.5
- mpmath==1.3.0
- multidict==6.0.5
- networkx==3.3
- pandas==2.2.2
- psutil==5.9.8
- pympler==1.0.1
- pyparsing==3.1.2
- python-dateutil==2.9.0.post0
- pytz==2024.1
- scikit-learn==1.4.2
- scipy==1.13.0
- six==1.16.0
- sympy==1.12.1
- threadpoolctl==3.4.0
- torch==2.3.1
- torch-geometric==2.5.3
- torchdata==0.7.1
- tqdm==4.66.2
- tree-lstm==0.0.8
- tzdata==2024.1
- yarl==1.9.4
79 changes: 79 additions & 0 deletions ml/gnn/src/calculate_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from decimal import Decimal
import json
import os
from utils import script_dir

report_dir = f"{script_dir}/../../.output/gnn/"

def dataset_metrics():
return {
"f1-score": 0,
"precision": 0,
"recall": 0,
"support": 0,
}

def method_metrics():
return {
"accuracy": 0,
"weighted avg": dataset_metrics(),
"macro avg": dataset_metrics(),
}

def model_metrics():
return {
"train": method_metrics(),
"validation": method_metrics(),
"test": method_metrics(),
}
models = {
"gcn": model_metrics(),
"gat": model_metrics(),
}

metrics = [ "f1-score", "precision", "recall", "support"]
methods = ["weighted avg", "macro avg"]

# for every subdir in report_dir
# for every file in subdir
# read file
num_seeds = 0
for seed_dir in os.listdir(report_dir):
seed_dir_path = os.path.join(report_dir, seed_dir)
if not os.path.isdir(seed_dir_path):
continue
num_seeds += 1
for model_dir in os.listdir(seed_dir_path):
model_dir_path = os.path.join(seed_dir_path, model_dir)
if not os.path.isdir(model_dir_path):
continue
for report_file in os.listdir(model_dir_path):
report_file_path = os.path.join(model_dir_path, report_file)
if not os.path.isfile(report_file_path):
continue
with open(report_file_path, "r") as f:
serialized = f.read()
deserialized = json.loads(serialized)
report_name = report_file.replace(".json", "")
models[model_dir][report_name]["accuracy"] += deserialized["accuracy"]
for metric in metrics:
for method in methods:
models[model_dir][report_name][method][metric] += deserialized[method][metric]

for model in models:
for dataset in models[model]:
models[model][dataset]["accuracy"] = float(
round(Decimal(models[model][dataset]["accuracy"] / num_seeds), 3)
)
for method in methods:
for metric in metrics:
models[model][dataset][method][metric] = float(
round(
Decimal(models[model][dataset][method][metric] / num_seeds), 3
)
)

final = json.dumps(models, indent=4)
print(final)
with open(f"{report_dir}/final_report.json", "w") as file:
file.write(final)
4 changes: 2 additions & 2 deletions ml/gnn/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, name: str, dataset_file: str, layout: Layout):
def load(self):
dataset_load_start_time = time.perf_counter()
if self.is_cached:
self.data, self.slices, self.num_nodes, self.metadata = torch.load(
self.data, self.slices, self.num_nodes, self.metadata, self.node_counts, self.num_nodes = torch.load(
self.dataset_cache_file
)
self.to(device)
Expand All @@ -53,7 +53,7 @@ def load(self):
self.node_counts = [len(data.x) for data in data_entries]
self.num_nodes = sum(self.node_counts)
torch.save(
(base_data, slices, self.num_nodes, self.metadata),
(base_data, slices, self.num_nodes, self.metadata, self.node_counts, self.num_nodes),
self.dataset_cache_file,
)
self.data, self.slices = base_data, slices
Expand Down
44 changes: 33 additions & 11 deletions ml/gnn/src/gnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import os
import random
from rich.console import Console
from rich.live import Live
from rich.layout import Layout
Expand All @@ -10,7 +12,13 @@
from model.gat import GATModel
from model.gcn import GCNModel

torch.manual_seed(42)
seed = sys.argv[4]

if seed is None:
exit("Please provide the random seed as an argument")

torch.manual_seed(seed)
random.seed(seed)

train_dataset_file = sys.argv[1]
validation_dataset_file = sys.argv[2]
Expand Down Expand Up @@ -102,13 +110,15 @@
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
).fit(
)
gat.fit(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
num_epochs=num_epochs,
start_epoch=start_epoch,
patience=patience,
).evaluate(
)
gat_report = gat.evaluate(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
Expand All @@ -118,27 +128,39 @@
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
).fit(
)
gcn.fit(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
num_epochs=num_epochs,
start_epoch=start_epoch,
patience=patience,
).evaluate(
)
gcn_report = gcn.evaluate(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
test_dataset=test_dataset,
)

output_dir = f"{script_dir}/../../.output/gnn/{seed}"

def save_report(model_name, report):
report_dir = f"{output_dir}/{model_name}"
os.makedirs(report_dir, exist_ok=True)
for name, value in report.items():
serialized = json.dumps(value, indent=4)
with open(f"{report_dir}/{name}.json", "w") as file:
file.write(serialized)


console = Console()
with console.capture() as capture:
console.print(layout)
output = capture.get()
output_dir = f"{script_dir}/../../.output"
output_file = "gnn"
if "NAME" in os.environ:
output_file = f"{output_file}_{os.environ['NAME']}"
output_path = f"{output_dir}/{output_file}.log"
with open(output_path, "w") as file:
# ensure out dir exists
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/log.txt", "w") as file:
file.write(output)
save_report("gat", gat_report)
save_report("gcn", gcn_report)

13 changes: 9 additions & 4 deletions ml/gnn/src/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataset import CM2MLDataset
from layout_proxy import LayoutProxy
from utils import device, pretty_duration, script_dir, text_padding
from sklearn.metrics import classification_report


def accuracy(logits: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, int]:
Expand Down Expand Up @@ -163,8 +164,12 @@ def evaluate_dataset(self, dataset: CM2MLDataset) -> None:
total_top_n_prediction_count = 0
total_weighted_correct_predictions = 0
total_weighted_prediction_count = 0
preds = []
labels = []
for data in dataset:
out = self.forward(data)
preds.extend(out.argmax(dim=1).cpu().numpy())
labels.extend(data.y.cpu().numpy())
(
correct_predictions,
prediction_count,
Expand Down Expand Up @@ -196,6 +201,9 @@ def evaluate_dataset(self, dataset: CM2MLDataset) -> None:
self.layout_proxy.print(
f"{text_padding}{dataset.name}: Acc: {total_accuracy:.2%}, Pred: {total_correct_predictions:.0f}/{total_prediction_count}, Acc@{dataset.top_n}: {total_top_n_accuracy:.2%}, Pred@{dataset.top_n}: {total_top_n_correct_predictions:.0f}/{total_top_n_prediction_count}, Wgth: {total_weighted_accuracy:.2%}"
)
report = classification_report(labels, preds, output_dict=True)
return report


def evaluate(
self,
Expand All @@ -204,7 +212,4 @@ def evaluate(
test_dataset: CM2MLDataset,
):
self.layout_proxy.print("Evaluating...")
self.evaluate_dataset(train_dataset)
self.evaluate_dataset(validation_dataset)
self.evaluate_dataset(test_dataset)
return self
return { "train": self.evaluate_dataset(train_dataset), "validation": self.evaluate_dataset(validation_dataset), "test": self.evaluate_dataset(test_dataset) }
15 changes: 14 additions & 1 deletion ml/scripts/train-gnn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ rm -f gnn/.cache/graph_train.json.dataset
rm -f gnn/.cache/graph_validation.json.dataset
rm -f gnn/.cache/graph_test.json.dataset

rm -rf .output/gnn

source scripts/conda-activate.sh

python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 42
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 43
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 44
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 45
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 46
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 47
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 48
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 49
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 50
python gnn/src/gnn.py graph_train.json graph_validation.json graph_test.json 51

python gnn/src/calculate_metrics.py

0 comments on commit 38c19f4

Please sign in to comment.