diff --git a/examples/rohehan/rohehan_trainer.py b/examples/rohehan/rohehan_trainer.py index 903f1029..d4981782 100644 --- a/examples/rohehan/rohehan_trainer.py +++ b/examples/rohehan/rohehan_trainer.py @@ -10,7 +10,8 @@ from gammagl.models import RoheHAN from utils import * import pickle as pkl -from gammagl.utils.convert import edge_index_to_adj_matrix +from gammagl.utils import mask_to_index +from gammagl.utils import edge_index_to_adj_matrix from gammagl.datasets.acm4rohe import ACM4Rohe class SemiSpvzLoss(tlx.nn.Module): @@ -42,9 +43,8 @@ def evaluate(model, data, labels, mask, loss_func): def main(args): # Load ACM raw dataset dataname = 'acm' - dataset = ACM4Rohe(root = "./") + dataset = ACM4Rohe(root = args.dataset_path) g = dataset[0] - dataset.download_attack_data_files() features_dict = {ntype: g[ntype].x for ntype in g.node_types if hasattr(g[ntype], 'x')} labels = g['paper'].y train_mask = g['paper'].train_mask @@ -55,9 +55,9 @@ def main(args): num_classes = int(tlx.reduce_max(labels)) + 1 # Get train_idx, val_idx, test_idx from masks - train_idx = np.where(train_mask)[0] - val_idx = np.where(val_mask)[0] - test_idx = np.where(test_mask)[0] + train_idx = mask_to_index(train_mask) + val_idx = mask_to_index(val_mask) + test_idx = mask_to_index(test_mask) x_dict = features_dict y = labels @@ -123,10 +123,6 @@ def main(args): "y": y } - # Ensure the best model path exists - if not os.path.exists(args.best_model_path): - os.makedirs(args.best_model_path) - # Training loop best_val_acc = 0.0 @@ -159,7 +155,8 @@ def main(args): tar_idx = [] # can attack 500 target nodes by seting range(5) for i in range(1): - with open(f'data/preprocess/target_nodes/{dataname}_r_target{i}.pkl', 'rb') as f: + target_filename = os.path.join(args.dataset_path, f'ACM4Rohe/raw/data/preprocess/target_nodes/acm_r_target{i}.pkl') + with open(target_filename, 'rb') as f: tar_tmp = np.sort(pkl.load(f)) tar_idx.extend(tar_tmp) @@ -173,7 +170,7 @@ def main(args): # Load adversarial 
attacks n_perturbation = 1 - adv_filename = f'data/generated_attacks/adv_acm_pap_pa_{n_perturbation}.pkl' + adv_filename = os.path.join(args.dataset_path, 'ACM4Rohe/raw/data/generated_attacks', f'adv_acm_pap_pa_{n_perturbation}.pkl') with open(adv_filename, 'rb') as f: modified_opt = pkl.load(f) @@ -266,6 +263,7 @@ def main(args): parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay.") parser.add_argument("--num_epochs", type=int, default=100, help="Number of training epochs.") parser.add_argument("--gpu", type=int, default=0, help="GPU index. Use -1 for CPU.") + parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset") parser.add_argument("--best_model_path", type=str, default='./', help="Path to save the best model.") args = parser.parse_args() @@ -276,4 +274,4 @@ def main(args): else: tlx.set_device("CPU") - main(args) + main(args) \ No newline at end of file diff --git a/gammagl/datasets/acm4rohe.py b/gammagl/datasets/acm4rohe.py index 505f375f..0c84f0d8 100644 --- a/gammagl/datasets/acm4rohe.py +++ b/gammagl/datasets/acm4rohe.py @@ -9,9 +9,9 @@ class ACM4Rohe(InMemoryDataset): r"""The ACM dataset for heterogeneous graph neural networks, consisting of nodes of types - :obj:`"paper"`, :obj:`"author"`, and :obj:`"field"`. This dataset was adapted from - `"Heterogeneous Graph Attention Network" <https://arxiv.org/abs/1903.07293>`_, - and is typically used for semi-supervised node classification in + :obj:`"paper"`, :obj:`"author"`, and :obj:`"field"`. This dataset was adapted from + `"Heterogeneous Graph Attention Network" <https://arxiv.org/abs/1903.07293>`_, + and is typically used for semi-supervised node classification in heterogeneous graphs. Parameters @@ -19,15 +19,15 @@ class ACM4Rohe(InMemoryDataset): root: str, optional Root directory where the dataset should be saved. transform: callable, optional - A function/transform that takes in an :obj:`HeteroGraph` object - and returns a transformed version. 
The data object will be transformed before + A function/transform that takes in an :obj:`HeteroGraph` object + and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform: callable, optional - A function/transform that takes in an :obj:`HeteroGraph` object - and returns a transformed version. The data object will be transformed before + A function/transform that takes in an :obj:`HeteroGraph` object + and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) force_reload: bool, optional - Whether to re-process the dataset, even if it has already been processed. + Whether to re-process the dataset, even if it has already been processed. (default: :obj:`False`) Attributes @@ -45,23 +45,30 @@ def __init__(self, root: Optional[str] = None, transform: Optional[Callable] = N @property def raw_file_names(self) -> List[str]: - return ["ACM.mat"] + return [ + "ACM.mat", + "data/generated_attacks/adv_acm_pap_pa_1.pkl", + "data/generated_attacks/adv_acm_pap_pa_3.pkl", + "data/generated_attacks/adv_acm_pap_pa_5.pkl", + "data/preprocess/target_nodes/acm_r_target0.pkl", + "data/preprocess/target_nodes/acm_r_target1.pkl", + "data/preprocess/target_nodes/acm_r_target2.pkl", + "data/preprocess/target_nodes/acm_r_target3.pkl", + "data/preprocess/target_nodes/acm_r_target4.pkl" + ] @property def processed_file_names(self) -> str: return tlx.BACKEND + "_data.pt" def download(self): + # Download the main ACM dataset file download_url(self.url, self.raw_dir) - def download_attack_data_files(self): - r"""Download additional data files required for adversarial attacks on the ACM dataset. - - This method checks if files needed for adversarial attack simulations are present in - the data directory. If any files are missing, it downloads them from a predefined URL. 
- """ + # Download additional adversarial attack data files if missing base_url = "https://raw.githubusercontent.com/BUPT-GAMMA/RoHe/main/Code/data" + # List of required adversarial files to download files_to_download = [ "generated_attacks/adv_acm_pap_pa_1.pkl", "generated_attacks/adv_acm_pap_pa_3.pkl", @@ -73,11 +80,14 @@ def download_attack_data_files(self): "preprocess/target_nodes/acm_r_target4.pkl" ] + # Download each file if not already present in its designated path for file_path in files_to_download: file_url = f"{base_url}/{file_path}" - save_folder = os.path.join("data", os.path.dirname(file_path)) + save_folder = os.path.join(self.raw_dir, "data", os.path.dirname(file_path)) + os.makedirs(save_folder, exist_ok=True) # Ensure save directory exists - if not os.path.exists(os.path.join(save_folder, os.path.basename(file_path))): + save_path = os.path.join(save_folder, os.path.basename(file_path)) + if not os.path.exists(save_path): download_url(file_url, save_folder) def process(self): @@ -154,10 +164,9 @@ def process(self): graph['field', 'fp', 'paper'].edge_index = edge_index_fp graph['paper'].y = labels - graph['paper'].train_mask = train_mask - graph['paper'].val_mask = val_mask - graph['paper'].test_mask = test_mask - + graph['paper'].train_mask = tlx.convert_to_tensor(train_mask, dtype=tlx.bool) + graph['paper'].val_mask = tlx.convert_to_tensor(val_mask, dtype=tlx.bool) + graph['paper'].test_mask = tlx.convert_to_tensor(test_mask, dtype=tlx.bool) if self.pre_transform is not None: graph = self.pre_transform(graph) @@ -165,12 +174,12 @@ def process(self): def get_meta_graph(self, dataname, given_adj_dict, features_dict, labels=None, train_mask=None, val_mask=None, test_mask=None): r"""Creates a meta-path based `HeteroGraph` for the ACM dataset. 
- - This function constructs a `HeteroGraph` with meta-path based edges + + This function constructs a `HeteroGraph` with meta-path based edges between `paper` nodes, representing the meta-paths: - Paper -> Author -> Paper (PAP) - Paper -> Field -> Paper (PFP) - + """ meta_graph = HeteroGraph() meta_graph['paper'].x = features_dict['paper'] @@ -184,4 +193,4 @@ def get_meta_graph(self, dataname, given_adj_dict, features_dict, labels=None, t meta_graph['paper'].val_mask = val_mask meta_graph['paper'].test_mask = test_mask - return meta_graph + return meta_graph \ No newline at end of file