feat(tree-lstm): calculate metrics
DerYeger committed Oct 13, 2024
1 parent 38c19f4 commit 84262d4
Showing 6 changed files with 48 additions and 43 deletions.
32 changes: 14 additions & 18 deletions ml/gnn/src/gnn.py
@@ -70,8 +70,6 @@

num_node_features = train_dataset.num_features
num_edge_features = train_dataset.num_edge_features
- hidden_channels = max_num_classes * 2
- out_channels = max_num_classes

if (
train_dataset.num_features != validation_dataset.num_features
@@ -87,14 +85,14 @@
gat = GATModel(
num_node_features=num_node_features,
num_edge_features=num_edge_features,
- hidden_channels=hidden_channels,
- out_channels=out_channels,
+ hidden_channels=max_num_classes * 2,
+ out_channels=max_num_classes,
layout=layout["models"]["gat"],
)
gcn = GCNModel(
num_node_features=num_node_features,
- hidden_channels=hidden_channels,
- out_channels=out_channels,
+ hidden_channels=max_num_classes,
+ out_channels=max_num_classes,
layout=layout["models"]["gcn"],
)

@@ -106,11 +104,11 @@
validation_dataset.print_and_calculate_label_metrics(max_num_classes)
test_dataset.print_and_calculate_label_metrics(max_num_classes)

- gat.evaluate(
- train_dataset=train_dataset,
- validation_dataset=validation_dataset,
- test_dataset=test_dataset,
- )
+ # gat.evaluate(
+ # train_dataset=train_dataset,
+ # validation_dataset=validation_dataset,
+ # test_dataset=test_dataset,
+ # )
gat.fit(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
@@ -124,11 +122,11 @@
test_dataset=test_dataset,
)

- gcn.evaluate(
- train_dataset=train_dataset,
- validation_dataset=validation_dataset,
- test_dataset=test_dataset,
- )
+ # gcn.evaluate(
+ # train_dataset=train_dataset,
+ # validation_dataset=validation_dataset,
+ # test_dataset=test_dataset,
+ # )
gcn.fit(
train_dataset=train_dataset,
validation_dataset=validation_dataset,
@@ -152,7 +150,6 @@ def save_report(model_name, report):
with open(f"{report_dir}/{name}.json", "w") as file:
file.write(serialized)

-
console = Console()
with console.capture() as capture:
console.print(layout)
@@ -163,4 +160,3 @@ def save_report(model_name, report):
file.write(output)
save_report("gat", gat_report)
save_report("gcn", gcn_report)
-
7 changes: 5 additions & 2 deletions ml/gnn/src/model/base_model.py
@@ -204,12 +204,15 @@ def evaluate_dataset(self, dataset: CM2MLDataset) -> None:
report = classification_report(labels, preds, output_dict=True)
return report

-
def evaluate(
self,
train_dataset: CM2MLDataset,
validation_dataset: CM2MLDataset,
test_dataset: CM2MLDataset,
):
self.layout_proxy.print("Evaluating...")
- return { "train": self.evaluate_dataset(train_dataset), "validation": self.evaluate_dataset(validation_dataset), "test": self.evaluate_dataset(test_dataset) }
+ return {
+ "train": self.evaluate_dataset(train_dataset),
+ "validation": self.evaluate_dataset(validation_dataset),
+ "test": self.evaluate_dataset(test_dataset),
+ }
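For readers unfamiliar with the nested structure this method now returns, here is a self-contained sketch with toy labels (not repository data) that relies only on scikit-learn's classification_report(..., output_dict=True) layout: each split maps class labels plus "macro avg" and "weighted avg" to precision/recall/f1-score/support, alongside a scalar "accuracy" entry.

from sklearn.metrics import classification_report

# Toy labels and predictions, purely illustrative.
labels = [0, 1, 2, 2, 1, 0]
preds = [0, 1, 1, 2, 1, 0]

# Same shape as evaluate()'s return value: one report dict per split.
report = {
    "train": classification_report(labels, preds, output_dict=True),
    "validation": classification_report(labels, preds, output_dict=True),
    "test": classification_report(labels, preds, output_dict=True),
}

for split, split_report in report.items():
    print(split, split_report["accuracy"], split_report["weighted avg"]["f1-score"])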
8 changes: 4 additions & 4 deletions ml/scripts/encode-graph.sh
@@ -1,14 +1,14 @@
- train_size=1600
- validation_size=400
- test_size=2000
+ train_size=600
+ validation_size=200
+ test_size=200

train_start=1
validation_start=$(($train_size + $train_start))
test_start=$(($validation_start + $validation_size + 1))

input=../models/uml/dataset

- parameters=("--strict" "--deduplicate" "--continue-on-error" "--relationships-as-edges" "--raw-strings" "--only-encoded-features" "--edge-tag-as-attribute")
+ parameters=("--strict" "--deduplicate" "--continue-on-error" "--relationships-as-edges" "true" "--raw-strings" "--only-encoded-features" "--edge-tag-as-attribute" "true" "--only-containment-associations" "false")

time bun node_modules/@cm2ml/cli/bin/cm2ml.mjs batch-uml-raw-graph "$input" --start "$train_start" --limit "$train_size" --out .input/graph_train.json "${parameters[@]}" && \
time bun node_modules/@cm2ml/cli/bin/cm2ml.mjs batch-uml-raw-graph "$input" --start "$validation_start" --limit "$validation_size" --out .input/graph_validation.json "${parameters[@]}" --node-features ".input/graph_train.json" --edge-features ".input/graph_train.json" --deduplication-data ".input/graph_train.json" && \
8 changes: 4 additions & 4 deletions ml/scripts/encode-tree.sh
@@ -1,14 +1,14 @@
- train_size=200
- validation_size=30
- test_size=30
+ train_size=600
+ validation_size=200
+ test_size=200

train_start=1
validation_start=$(($train_size + $train_start))
test_start=$(($validation_start + $validation_size + 1))

input=../models/uml/dataset

- parameters=("--strict" "--deduplicate" "--format" "local" "--continue-on-error" "--relationships-as-edges" "--pretty" "--only-containment-associations" "--replace-node-ids" "--raw-strings" "--only-encoded-features")
+ parameters=("--strict" "--deduplicate" "--format" "local" "--continue-on-error" "--relationships-as-edges" "true" "--pretty" "--only-containment-associations" "true" "--replace-node-ids" "--raw-strings" "--only-encoded-features")

time bun node_modules/@cm2ml/cli/bin/cm2ml.mjs batch-uml-tree "$input" --start "$train_start" --limit "$train_size" --out .input/tree_train.json "${parameters[@]}" && \
time bun node_modules/@cm2ml/cli/bin/cm2ml.mjs batch-uml-tree "$input" --start "$validation_start" --limit "$validation_size" --out .input/tree_validation.json "${parameters[@]}" --node-features ".input/tree_train.json" --edge-features ".input/tree_train.json" --deduplication-data ".input/tree_train.json" && \
26 changes: 21 additions & 5 deletions ml/tree-lstm/src/paper/mdeoperation.py
@@ -8,6 +8,8 @@
import sys
import time
from typing import Union
+ import numpy as np
+ from sklearn.metrics import classification_report
import torch

from torch import cuda
@@ -18,6 +20,7 @@
from tree_dataset import TreeDataset
from utils import script_dir

+
def create_model(
source_vocab: data_utils.Vocab,
target_vocab: data_utils.Vocab,
@@ -101,7 +104,7 @@ def step_tree2tree(

def evaluate(
model: network.Tree2TreeModel,
- test_dataset: data_utils.EncodedDataset,  # TODO/Jan: Type!
+ test_dataset: data_utils.EncodedDataset,
source_vocab: data_utils.Vocab,
target_vocab: data_utils.Vocab,
):
@@ -112,6 +115,8 @@ def evaluate(
tot_trees = len(test_dataset)
res = []
# model.eval()
+ preds = []
+ labels = []

for idx in range(0, len(test_dataset), args.batch_size):
encoder_inputs, decoder_inputs = model.get_batch(test_dataset, start_idx=idx)
@@ -158,6 +163,18 @@
# print(f"output {current_output_print}")
# print("---")

+ labels.extend(current_target)
+ # extends preds with current_output, but ensure length matches to len(current_target)
+ shortened_output = current_output[: len(current_target)]
+ padded_output = shortened_output + [0] * (
+ len(current_target) - len(shortened_output)
+ )
+ preds.extend(padded_output)
+
+ if len(labels) == 1:
+ print(f"labels: {labels}")
+ print(f"preds: {preds}")
+
tot_tokens += len(current_target)
all_correct = len(current_target) == len(current_output)
wrong_tokens = 0
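As a side note on the truncate-or-pad step added above, here is a minimal, self-contained sketch with toy token ids (not from the dataset), assuming the same padding value of 0: the decoded output is cut or padded so that preds and labels always end up with equal length before the report is computed.

def align_output(current_output: list[int], current_target: list[int]) -> list[int]:
    # Cut the prediction to the target length, then pad with 0 if it is shorter.
    shortened = current_output[: len(current_target)]
    return shortened + [0] * (len(current_target) - len(shortened))

print(align_output([5, 7, 9, 9], [5, 7, 9]))  # too long  -> [5, 7, 9]
print(align_output([5, 7], [5, 7, 9]))        # too short -> [5, 7, 0]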
@@ -178,6 +195,7 @@
print(" test: accuracy of tokens %.2f" % (acc_tokens * 1.0 / tot_tokens))
print(" test: accuracy of programs %.2f" % (acc_trees * 1.0 / tot_trees))
print(acc_tokens, tot_tokens, acc_trees, tot_trees)
+ print(classification_report(labels, preds, zero_division=np.nan))


def train(
Expand Down Expand Up @@ -240,7 +258,7 @@ def train(
model, encoder_inputs, decoder_inputs, feed_previous=False
)

epoch_time += (time.time() - start_time)
epoch_time += time.time() - start_time
loss += batch_loss
current_step += 1

Expand Down Expand Up @@ -281,9 +299,7 @@ def train(
else:
remaining_patience -= 1
if remaining_patience == 0:
print(
f"Early stopping in epoch {epoch}"
)
print(f"Early stopping in epoch {epoch}")
break
sys.stdout.flush()
time_training = datetime.datetime.now() - start_datetime
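Regarding the zero_division=np.nan argument passed to classification_report in evaluate() above, a self-contained toy example (not repository data) shows why it matters when a class never appears in the predictions: the undefined scores are reported as NaN rather than 0 and left out of the averaged metrics. This assumes a recent scikit-learn release that accepts np.nan here.

import numpy as np
from sklearn.metrics import classification_report

# Class 2 occurs in the labels but is never predicted, so its precision is 0/0.
labels = [0, 1, 2, 0, 1, 2]
preds = [0, 1, 0, 0, 1, 1]

print(classification_report(labels, preds, zero_division=np.nan))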
10 changes: 0 additions & 10 deletions ml/tree-lstm/src/utils.py
@@ -1,9 +1,6 @@
import os
import time
import torch
- from rich.align import Align
- from rich.panel import Panel
- from rich.spinner import Spinner

text_padding = " " * 2

@@ -19,13 +16,6 @@ def pretty_duration(duration_seconds: int) -> str:
device = torch.device("mps" if use_mps else "cpu")


- def WaitingSpinner(title: str):
- return Panel(
- Align.center(Spinner("dots", text="Waiting..."), vertical="middle"),
- title=title,
- )
-
-
def merge_vocabularies(vocabularies: list[list[str]]) -> list[str]:
total_vocabulary = []
for vocabulary in vocabularies:
