Commit

Fix errors and refactoring
dongkwan-kim committed Jun 29, 2024
1 parent 9cae252 commit eddd5d5
Showing 3 changed files with 20 additions and 20 deletions.
WL4S/wl4s.py (11 changes: 6 additions & 5 deletions)
@@ -25,7 +25,7 @@
ModelType = Union[MultiOutputClassifier, LinearSVC, SVC]

DATASETS_REAL = ["PPIBP", "EMUser", "HPOMetab", "HPONeuro"]
DATASETS_SYN = ["Component", "Density", "Coreness", "CutRatio"]
DATASETS_SYN = ["CutRatio", "Density", "Coreness", "Component"]
MODEL_KWARGS_KEY = ["C", "kernel", "dual"]

parser = argparse.ArgumentParser()
@@ -47,9 +47,10 @@


class WL4S(torch.nn.Module):
-def __init__(self, stype, num_layers, norm,
+def __init__(self, dataset_name, stype, num_layers, norm,
dtype="histogram", splits=None, k_to_sample=None, precompute=False):
super(WL4S, self).__init__()
+self.dataset_name = dataset_name
self.stype = stype
self.norm = norm
self.dtype = dtype
@@ -77,7 +78,7 @@ def forward(self, x, edge_index, batch_or_sub_batch, x_to_xs=None, mask=None):
if self.dtype == "kernel":
kernel_key = kk(self.k_to_sample, self.stype, self.norm, i)
if self.splits[-1] <= 250: # synthetic graphs, backward compatibility
-kernel_key = kk(self.k_to_sample, self.stype[:3], self.norm, i, edge_index.size(1) // 100)
+kernel_key = kk(self.k_to_sample, self.stype[:3], self.norm, i, self.dataset_name)

h = hist_linear_kernels(hist=h, splits=self.splits, key=kernel_key)
if self.precompute:
@@ -89,7 +90,7 @@ def forward(self, x, edge_index, batch_or_sub_batch, x_to_xs=None, mask=None):


def kk(k_to_sample, stype, norm, i, *args):
return "_".join([str(s) for s in [k_to_sample, stype, norm, i, *args]])
return "_".join([str(s) for s in [*args, k_to_sample, stype, norm, i]])


@fscaches(path="../_caches", keys_to_exclude=["hist"], verbose=True)
@@ -157,7 +158,7 @@ def get_data_and_model(args, precompute=False):
f"{splits} != [{len(train_dts), len(val_dts), len(test_dts)}]"
all_data = Batch.from_data_list(train_dts + val_dts + test_dts)

-wl = WL4S(stype=args.stype, num_layers=args.wl_layers, norm=args.hist_norm,
+wl = WL4S(dataset_name=args.dataset_name, stype=args.stype, num_layers=args.wl_layers, norm=args.hist_norm,
dtype=args.dtype, splits=splits, k_to_sample=args.k_to_sample, precompute=precompute)

if args.stype == "connected":
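The kk() change above reorders the cache key consumed by the @fscaches-decorated kernel computation: the extra positional arguments, now the dataset name threaded through the new dataset_name attribute of WL4S rather than an edge-count bucket, lead the key instead of trailing it. A minimal sketch of the two versions side by side; the argument values are illustrative, not taken from the repository:

```python
# Before vs. after: only the order of fields in the joined key changes.
def kk_old(k_to_sample, stype, norm, i, *args):
    return "_".join([str(s) for s in [k_to_sample, stype, norm, i, *args]])

def kk_new(k_to_sample, stype, norm, i, *args):
    return "_".join([str(s) for s in [*args, k_to_sample, stype, norm, i]])

print(kk_old(None, "sep", True, 2, "Density"))  # None_sep_True_2_Density
print(kk_new(None, "sep", True, 2, "Density"))  # Density_None_sep_True_2
```

Keying the cached kernels on the dataset name, rather than on edge_index.size(1) // 100, presumably prevents two synthetic datasets with similar edge counts from sharing a cache entry, which fits the "fix errors" part of the commit message.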
WL4S/wl4s2v2.py (6 changes: 4 additions & 2 deletions)
@@ -2,6 +2,8 @@


def get_data_mixed_kernels(args):
+assert args.dtype == "kernel"
+
args.stype = "separated"
k_list_s, splits_s, y_s = get_data_and_model(args)

@@ -45,8 +47,8 @@ def get_data_mixed_kernels(args):
run_one(__args__)
else:
for _a_c, _a_s in [
-(0.99, 0.01), (0.9, 0.1), (0.5, 0.1),
-(0.01, 0.99), (0.1, 0.9), (0.1, 0.5),
+(0.999, 0.001), (0.99, 0.01), (0.9, 0.1),
+(0.001, 0.999), (0.01, 0.99), (0.1, 0.9),
]:
__args__.a_c, __args__.a_s = _a_c, _a_s

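get_data_mixed_kernels() now asserts args.dtype == "kernel" up front, and the (a_c, a_s) grid is pushed toward more extreme mixing weights. The diff does not show how a_c and a_s are applied; the sketch below assumes they form a weighted sum of the connected- and separated-subgraph Gram matrices, which is an assumption about the rest of wl4s2v2.py, not something this commit confirms:

```python
import numpy as np

def mix_kernels(K_connected: np.ndarray, K_separated: np.ndarray,
                a_c: float, a_s: float) -> np.ndarray:
    # Hypothetical combination rule: a weighted sum of two Gram matrices.
    return a_c * K_connected + a_s * K_separated

# Toy Gram matrices; the new grid favors one kernel almost exclusively.
K_c, K_s = np.eye(4), np.ones((4, 4))
for a_c, a_s in [(0.999, 0.001), (0.99, 0.01), (0.9, 0.1)]:
    K_mixed = mix_kernels(K_c, K_s, a_c, a_s)
```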
WL4S/wl4s_k.py (23 changes: 10 additions & 13 deletions)
@@ -7,6 +7,9 @@
"stype": ["separated"],
"wl_cumcat": [False, True],
"hist_norm": [False, True],
"model": ["SVC"],
"kernel": ["precomputed"],
"dtype": ["kernel"],
}
Cx100 = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576]
MORE_HPARAM_SPACE = {
@@ -18,26 +21,19 @@

MODE = __args__.MODE
if MODE == "syn_k":
__args__.stype = "connected"
HPARAM_SPACE["stype"] = ["connected"]
kws = dict(file_dir="../_logs_wl4s_k", log_postfix=f"_inf")
hp_search_syn(__args__, HPARAM_SPACE, MORE_HPARAM_SPACE, **kws)

__args__.stype = "separated"
HPARAM_SPACE = {**HPARAM_SPACE, "model": ["LinearSVC"]}
MORE_HPARAM_SPACE = {**MORE_HPARAM_SPACE, "dual": [True, False]}
HPARAM_SPACE["stype"] = ["separated"]
for k_to_sample in [None, 1, 2]:
__args__.k_to_sample = k_to_sample
kws = dict(file_dir="../_logs_wl4s_k", log_postfix=f"_{k_to_sample or 0}")
hp_search_syn(__args__, HPARAM_SPACE, MORE_HPARAM_SPACE, **kws)

__args__.stype = "connected"
HPARAM_SPACE["stype"] = ["connected"]
kws = dict(file_dir="../_logs_wl4s_k", log_postfix=f"_inf")
hp_search_syn(__args__, HPARAM_SPACE, MORE_HPARAM_SPACE, **kws)

else:
-HPARAM_SPACE = {
-**HPARAM_SPACE,
-"model": ["SVC"], "kernel": ["precomputed"], "dtype": ["kernel"],
-}
__args__.dtype = "kernel"

if MODE == "real_precomputation":
for k_to_sample in [None, 1, 2]:
for dataset_name in ["PPIBP", "EMUser"]:
@@ -62,6 +58,7 @@

elif MODE == "real_k":
for dataset_name in ["PPIBP", "EMUser", "HPOMetab", "HPONeuro"]:
+# OOM for HPOMetab & HPONeuro with k >= 1
k_to_sample_list = [None] if dataset_name in ["HPOMetab", "HPONeuro"] else [None, 1, 2]
for k_to_sample in k_to_sample_list:
__args__.k_to_sample = k_to_sample
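With "model": ["SVC"], "kernel": ["precomputed"], and "dtype": ["kernel"] folded into the base HPARAM_SPACE, every mode in wl4s_k.py, including syn_k, now appears to search over a precomputed-kernel SVC instead of overriding these settings per branch (the LinearSVC/dual overrides are dropped). A hedged sketch of what fitting such a model looks like in scikit-learn, with toy Gram matrices; the C scale assumes the Cx100 grid stores 100 * C, which this diff does not state:

```python
import numpy as np
from sklearn.svm import SVC

# With kernel="precomputed", SVC is fit on Gram matrices, not feature vectors:
# K_train is (n_train, n_train) and K_test is (n_test, n_train).
rng = np.random.default_rng(0)
X_train, X_test = rng.random((20, 5)), rng.random((5, 5))
y_train = np.array([0, 1] * 10)

K_train = X_train @ X_train.T  # e.g., a linear kernel over WL histograms
K_test = X_test @ X_train.T

for c in [8, 16, 32]:  # first entries of the Cx100 grid; the /100 scaling is an assumption
    clf = SVC(C=c / 100, kernel="precomputed").fit(K_train, y_train)
    preds = clf.predict(K_test)
```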
