Skip to content

Commit

Permalink
更新路径
Browse files Browse the repository at this point in the history
  • Loading branch information
n1108 committed Oct 31, 2024
1 parent d5e74b9 commit 673e903
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 38 deletions.
24 changes: 11 additions & 13 deletions examples/rohehan/rohehan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from gammagl.models import RoheHAN
from utils import *
import pickle as pkl
from gammagl.utils.convert import edge_index_to_adj_matrix
from gammagl.utils import mask_to_index
from gammagl.utils import edge_index_to_adj_matrix
from gammagl.datasets.acm4rohe import ACM4Rohe

class SemiSpvzLoss(tlx.nn.Module):
Expand Down Expand Up @@ -42,9 +43,8 @@ def evaluate(model, data, labels, mask, loss_func):
def main(args):
# Load ACM raw dataset
dataname = 'acm'
dataset = ACM4Rohe(root = "./")
dataset = ACM4Rohe(root = args.dataset_path)
g = dataset[0]
dataset.download_attack_data_files()
features_dict = {ntype: g[ntype].x for ntype in g.node_types if hasattr(g[ntype], 'x')}
labels = g['paper'].y
train_mask = g['paper'].train_mask
Expand All @@ -55,9 +55,9 @@ def main(args):
num_classes = int(tlx.reduce_max(labels)) + 1

# Get train_idx, val_idx, test_idx from masks
train_idx = np.where(train_mask)[0]
val_idx = np.where(val_mask)[0]
test_idx = np.where(test_mask)[0]
train_idx = mask_to_index(train_mask)
val_idx = mask_to_index(val_mask)
test_idx = mask_to_index(test_mask)

x_dict = features_dict
y = labels
Expand Down Expand Up @@ -123,10 +123,6 @@ def main(args):
"y": y
}

# Ensure the best model path exists
if not os.path.exists(args.best_model_path):
os.makedirs(args.best_model_path)

# Training loop
best_val_acc = 0.0

Expand Down Expand Up @@ -159,7 +155,8 @@ def main(args):
tar_idx = []
# can attack 500 target nodes by seting range(5)
for i in range(1):
with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f:
target_filename = os.path.join(args.dataset_path, f'ACM4Rohe/raw/data/preprocess/target_nodes/acm_r_target{i}.pkl')
with open(target_filename, 'rb') as f:
tar_tmp = np.sort(pkl.load(f))
tar_idx.extend(tar_tmp)

Expand All @@ -173,7 +170,7 @@ def main(args):

# Load adversarial attacks
n_perturbation = 1
adv_filename = f'data/generated_attacks/adv_acm_pap_pa_{n_perturbation}.pkl'
adv_filename = os.path.join(args.dataset_path, 'ACM4Rohe/raw/data/generated_attacks', f'adv_acm_pap_pa_{n_perturbation}.pkl')
with open(adv_filename, 'rb') as f:
modified_opt = pkl.load(f)

Expand Down Expand Up @@ -266,6 +263,7 @@ def main(args):
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay.")
parser.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.")
parser.add_argument("--gpu", type=int, default=0, help="GPU index. Use -1 for CPU.")
parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
parser.add_argument("--best_model_path", type=str, default='./', help="Path to save the best model.")
args = parser.parse_args()

Expand All @@ -276,4 +274,4 @@ def main(args):
else:
tlx.set_device("CPU")

main(args)
main(args)
59 changes: 34 additions & 25 deletions gammagl/datasets/acm4rohe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,25 @@

class ACM4Rohe(InMemoryDataset):
r"""The ACM dataset for heterogeneous graph neural networks, consisting of nodes of types
:obj:`"paper"`, :obj:`"author"`, and :obj:`"field"`. This dataset was adapted from
`"Heterogeneous Graph Attention Network" <https://github.com/Jhy1993/HAN>`_,
and is typically used for semi-supervised node classification in
:obj:`"paper"`, :obj:`"author"`, and :obj:`"field"`. This dataset was adapted from
`"Heterogeneous Graph Attention Network" <https://github.com/Jhy1993/HAN>`_,
and is typically used for semi-supervised node classification in
heterogeneous graphs.
Parameters
----------
root: str, optional
Root directory where the dataset should be saved.
transform: callable, optional
A function/transform that takes in an :obj:`HeteroGraph` object
and returns a transformed version. The data object will be transformed before
A function/transform that takes in an :obj:`HeteroGraph` object
and returns a transformed version. The data object will be transformed before
every access. (default: :obj:`None`)
pre_transform: callable, optional
A function/transform that takes in an :obj:`HeteroGraph` object
and returns a transformed version. The data object will be transformed before
A function/transform that takes in an :obj:`HeteroGraph` object
and returns a transformed version. The data object will be transformed before
being saved to disk. (default: :obj:`None`)
force_reload: bool, optional
Whether to re-process the dataset, even if it has already been processed.
Whether to re-process the dataset, even if it has already been processed.
(default: :obj:`False`)
Attributes
Expand All @@ -45,23 +45,30 @@ def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = N

@property
def raw_file_names(self) -> List[str]:
return ["ACM.mat"]
return [
"ACM.mat",
"data/generated_attacks/adv_acm_pap_pa_1.pkl",
"data/generated_attacks/adv_acm_pap_pa_3.pkl",
"data/generated_attacks/adv_acm_pap_pa_5.pkl",
"data/preprocess/target_nodes/acm_r_target0.pkl",
"data/preprocess/target_nodes/acm_r_target1.pkl",
"data/preprocess/target_nodes/acm_r_target2.pkl",
"data/preprocess/target_nodes/acm_r_target3.pkl",
"data/preprocess/target_nodes/acm_r_target4.pkl"
]

@property
def processed_file_names(self) -> str:
return tlx.BACKEND + "_data.pt"

def download(self):
# Download the main ACM dataset file
download_url(self.url, self.raw_dir)

def download_attack_data_files(self):
r"""Download additional data files required for adversarial attacks on the ACM dataset.
This method checks if files needed for adversarial attack simulations are present in
the data directory. If any files are missing, it downloads them from a predefined URL.
"""
# Download additional adversarial attack data files if missing
base_url = "https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/main/Code/data"

# List of required adversarial files to download
files_to_download = [
"generated_attacks/adv_acm_pap_pa_1.pkl",
"generated_attacks/adv_acm_pap_pa_3.pkl",
Expand All @@ -73,11 +80,14 @@ def download_attack_data_files(self):
"preprocess/target_nodes/acm_r_target4.pkl"
]

# Download each file if not already present in its designated path
for file_path in files_to_download:
file_url = f"{base_url}/{file_path}"
save_folder = os.path.join("data", os.path.dirname(file_path))
save_folder = os.path.join(self.raw_dir, "data", os.path.dirname(file_path))
os.makedirs(save_folder, exist_ok=True) # Ensure save directory exists

if not os.path.exists(os.path.join(save_folder, os.path.basename(file_path))):
save_path = os.path.join(save_folder, os.path.basename(file_path))
if not os.path.exists(save_path):
download_url(file_url, save_folder)

def process(self):
Expand Down Expand Up @@ -154,23 +164,22 @@ def process(self):
graph['field', 'fp', 'paper'].edge_index = edge_index_fp

graph['paper'].y = labels
graph['paper'].train_mask = train_mask
graph['paper'].val_mask = val_mask
graph['paper'].test_mask = test_mask

graph['paper'].train_mask = tlx.convert_to_tensor(train_mask, dtype=tlx.bool)
graph['paper'].val_mask = tlx.convert_to_tensor(val_mask, dtype=tlx.bool)
graph['paper'].test_mask = tlx.convert_to_tensor(test_mask, dtype=tlx.bool)
if self.pre_transform is not None:
graph = self.pre_transform(graph)

self.save_data(self.collate([graph]), self.processed_paths[0])

def get_meta_graph(self, dataname, given_adj_dict, features_dict, labels=None, train_mask=None, val_mask=None, test_mask=None):
r"""Creates a meta-path based `HeteroGraph` for the ACM dataset.
This function constructs a `HeteroGraph` with meta-path based edges
This function constructs a `HeteroGraph` with meta-path based edges
between `paper` nodes, representing the meta-paths:
- Paper -> Author -> Paper (PAP)
- Paper -> Field -> Paper (PFP)
"""
meta_graph = HeteroGraph()
meta_graph['paper'].x = features_dict['paper']
Expand All @@ -184,4 +193,4 @@ def get_meta_graph(self, dataname, given_adj_dict, features_dict, labels=None, t
meta_graph['paper'].val_mask = val_mask
meta_graph['paper'].test_mask = test_mask

return meta_graph
return meta_graph

0 comments on commit 673e903

Please sign in to comment.