Commit
Merge pull request #206 from gjy1221/heat_branch
add model heat
gyzhou2000 authored Jun 11, 2024
2 parents 9f86d5e + d741516 commit 6f77579
Showing 10 changed files with 553 additions and 1 deletion.
153 changes: 153 additions & 0 deletions examples/heat/heat_trainer.py
@@ -0,0 +1,153 @@
import argparse
import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ['TL_BACKEND'] = 'torch'

import tensorlayerx as tlx
from gammagl.datasets import NGSIM_US_101
from gammagl.models import HEAT
from gammagl.loader import DataLoader
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fun):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fun)

    def forward(self, data, label):
        logits = self._backbone(data.x, data.edge_index, data.edge_attr, data.edge_type)

        # supervise only the target vehicles
        train_logits = tlx.gather(logits, data.tar_mask)
        train_y = tlx.gather(data.y, data.tar_mask)

        # RMSE: square root of the mean squared error
        loss = self._loss_fn(train_logits, train_y, reduction='mean')
        loss = tlx.sqrt(loss)

        return loss


def main(args):
    # load datasets
    train_set = NGSIM_US_101(root=args.data_path, name='train')
    val_set = NGSIM_US_101(root=args.data_path, name='val')
    test_set = NGSIM_US_101(root=args.data_path, name='test')

    trainDataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    valDataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True)
    testDataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True)

    net = HEAT(args.hist_length, args.in_channels_node, args.out_channels, args.out_length,
               args.in_channels_edge_attr, args.in_channels_edge_type, args.edge_attr_emb_size,
               args.edge_type_emb_size, args.node_emb_size, args.heads, args.concat,
               args.dropout, args.leaky_rate)

    print('loading HEAT model')

    # pass the scheduler to the optimizer (paddle-style API); otherwise
    # MultiStepDecay never updates the learning rate actually used
    scheduler = tlx.optimizers.lr.MultiStepDecay(learning_rate=args.lr,
                                                 milestones=[1, 2, 4, 6, 10, 30, 40, 50, 60],
                                                 gamma=0.7, verbose=True)
    optimizer = tlx.optimizers.Adam(lr=scheduler)

    train_weights = net.trainable_weights
    loss_fn = SemiSpvzLoss(net, tlx.losses.mean_squared_error)
    train_one_step = TrainOneStep(loss_fn, optimizer, train_weights)

    best_val_loss = 1000
    for epoch in range(args.n_epoch):
        train_loss_epo = 0.0
        net.set_train()
        for i, data in enumerate(trainDataloader):
            # keep only the first `out_length` future steps and flatten them
            indices = tlx.arange(0, args.out_length)
            data.y = tlx.gather(data.y, indices, axis=1)
            data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
            loss_each_data = train_one_step(data, data.y)
            train_loss_epo += loss_each_data

        # convert feet to meters (1 ft = 0.3048 m) and average over batches
        train_loss_epoch = round(train_loss_epo * 0.3048 / (i + 1), 4)

        val_loss_epoch = 0
        net.set_eval()
        for j, data in enumerate(valDataloader):
            logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
            indices = tlx.arange(0, args.out_length)
            data.y = tlx.gather(data.y, indices, axis=1)
            data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
            val_logits = tlx.gather(logits, data.tar_mask)
            val_y = tlx.gather(data.y, data.tar_mask)
            val_loss_epoch += tlx.convert_to_numpy(
                tlx.sqrt(tlx.losses.mean_squared_error(val_logits, val_y, reduction='mean')))

        val_loss_epoch = round(val_loss_epoch * 0.3048 / (j + 1), 4)

        print("Epoch [{:0>3d}] ".format(epoch + 1) + " train loss: {:.4f}".format(
            train_loss_epoch) + " val loss: {:.4f}".format(val_loss_epoch))

        # save the best model on the validation set
        if val_loss_epoch < best_val_loss:
            best_val_loss = val_loss_epoch
            net.save_weights(str(args.out_length) + '-' + str(best_val_loss) + '.npz', format='npz_dict')

        scheduler.step()

    # evaluate the best checkpoint on the test set
    net.set_eval()
    net.load_weights(str(args.out_length) + '-' + str(best_val_loss) + '.npz', format='npz_dict')
    total_distance = 0
    total_samples = 0

    for i, data in enumerate(testDataloader):
        logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
        indices = tlx.arange(0, args.out_length)
        data.y = tlx.gather(data.y, indices, axis=1)
        data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
        test_logits = tlx.gather(logits, data.tar_mask)
        test_y = tlx.gather(data.y, data.tar_mask)

        # Euclidean distance between predicted and ground-truth trajectories
        distance = tlx.sqrt(tlx.reduce_sum(tlx.square(test_logits - test_y)))
        total_distance += distance
        total_samples += len(test_logits)

    # average Euclidean distance per target vehicle; convert to a NumPy
    # scalar so string formatting works across backends
    average_distance = tlx.convert_to_numpy(total_distance / total_samples)
    print("Average Euclidean distance: {:.4f}".format(average_distance))


if __name__ == '__main__':
    # Network arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epoch", type=int, default=40, help="number of epochs")
    parser.add_argument("--in_channels_node", type=int, default=64, help="heat_in_channels_node")
    parser.add_argument("--in_channels_edge_attr", type=int, default=5, help="heat_in_channels_edge_attr")
    parser.add_argument("--in_channels_edge_type", type=int, default=6, help="heat_in_channels_edge_type")
    parser.add_argument("--edge_attr_emb_size", type=int, default=64, help="heat_edge_attr_emb_size")
    parser.add_argument("--edge_type_emb_size", type=int, default=64, help="heat_edge_type_emb_size")
    parser.add_argument("--node_emb_size", type=int, default=64, help="heat_node_emb_size")
    parser.add_argument("--out_channels", type=int, default=128, help="heat_out_channels")

    parser.add_argument("--heads", type=int, default=3, help="number of attention heads")
    # note: argparse's type=bool parses any non-empty string as True
    parser.add_argument('--concat', type=bool, default=True, help='heat_concat')
    parser.add_argument("--hist_length", type=int, default=10, help="length of history trajectory")
    parser.add_argument("--out_length", type=int, default=30, help="length of future trajectory")
    parser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
    parser.add_argument("--leaky_rate", type=float, default=0.1, help="negative slope of LeakyReLU")

    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=20, help="batch size")
    parser.add_argument("--data_path", type=str, default=r'', help="path to the dataset")
    parser.add_argument("--result_path", type=str, default=r'', help="path to save results")
    parser.add_argument("--device", type=int, default=0)

    args = parser.parse_args()

    if args.device >= 0:
        tlx.set_device("GPU", args.device)
    else:
        tlx.set_device("CPU")

    main(args)
30 changes: 30 additions & 0 deletions examples/heat/readme.md
@@ -0,0 +1,30 @@
Heterogeneous Edge-Enhanced Graph Attention Network (HEAT)
============

- Paper link: [https://arxiv.org/abs/2106.07161](https://arxiv.org/abs/2106.07161)
- Author's code repo (in PyTorch):
[https://github.com/Xiaoyu006/MATP-with-HEAT](https://github.com/Xiaoyu006/MATP-with-HEAT).

Dataset Statistics
-------

| Dataset      | # Graphs | # Node Types | # Edge Types |
|--------------|----------|--------------|--------------|
| NGSIM US-101 | 1201     | 2            | 2            |

Refer to [NGSIM US-101 Datasets](https://github.com/gjy1221/NGSIM-US-101).
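
A minimal loading sketch, mirroring `heat_trainer.py` (the `../data` path is a placeholder, and the snippet assumes batches expose `x` and `edge_index` as the trainer does):

```python
from gammagl.datasets import NGSIM_US_101
from gammagl.loader import DataLoader

# each split ('train', 'val', 'test') is a set of traffic graphs
train_set = NGSIM_US_101(root='../data', name='train')
loader = DataLoader(train_set, batch_size=20, shuffle=True)

for data in loader:
    # x: node features; edge_attr / edge_type: heterogeneous edge inputs
    print(data.x.shape, data.edge_index.shape)
    break
```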

Results
-------

```bash
TL_BACKEND="torch" python heat_trainer.py --data_path ../data --result_path ../result
```

RMSE (in meters, lower is better) at each prediction horizon:

| Time (sec) | Paper  | Ours (torch) |
|------------|--------|--------------|
| 1          | 0.6067 | 0.6940       |
| 2          | 0.8556 | 0.7349       |
| 3          | 1.0469 | 1.0621       |
| 4          | 1.3216 | 1.2739       |
| 5          | 1.8894 | 1.7562       |
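
For reference, a sketch of constructing the model with the trainer's defaults. The argument order is copied from the `HEAT(...)` call in `heat_trainer.py`; the names in the comment are the trainer's flag names, not necessarily the model's own parameter names:

```python
from gammagl.models import HEAT

# (hist_length, in_channels_node, out_channels, out_length,
#  in_channels_edge_attr, in_channels_edge_type, edge_attr_emb_size,
#  edge_type_emb_size, node_emb_size, heads, concat, dropout, leaky_rate)
model = HEAT(10, 64, 128, 30, 5, 6, 64, 64, 64, 3, True, 0.5, 0.1)

# forward pass on one batch from the loader sketch above
out = model(data.x, data.edge_index, data.edge_attr, data.edge_type)
```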
2 changes: 2 additions & 0 deletions gammagl/datasets/__init__.py
@@ -1,5 +1,6 @@
from .amazon import Amazon
from .coauthor import Coauthor
from .ngsim import NGSIM_US_101
from .tu_dataset import TUDataset
from .planetoid import Planetoid
from .reddit import Reddit
@@ -42,6 +43,7 @@
    'PolBlogs',
    'WikiCS',
    'MoleculeNet',
    'NGSIM_US_101',
    'Yelp'
]

79 changes: 79 additions & 0 deletions gammagl/datasets/ngsim.py
@@ -0,0 +1,79 @@
import os
import os.path as osp
import zipfile
import tensorlayerx as tlx
from gammagl.data import Graph, HeteroGraph, download_url, extract_zip
from gammagl.data import InMemoryDataset
from typing import Callable, Optional, List


class NGSIM_US_101(InMemoryDataset):
    r"""
    The NGSIM US-101 dataset from the `"NGSIM: Next Generation Simulation"
    <https://ops.fhwa.dot.gov/trafficanalysistools/ngsim.htm>`_ project,
    containing detailed vehicle trajectory data from the US-101 highway in
    Los Angeles, California.

    Parameters
    ----------
    root: str
        Root directory where the dataset should be saved.
    name: str, optional
        The name of the dataset split (:obj:`"train"`, :obj:`"val"`, :obj:`"test"`).
    transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in
        an :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset. (default: :obj:`False`)
    """

    url = 'https://github.com/gjy1221/NGSIM-US-101/raw/main/data'

    def __init__(self, root: str = None, name: str = None,
                 transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None,
                 force_reload: bool = False):
        self.name = osp.join('ngsim', name.lower())
        self.split = name.lower()
        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data_path = osp.join(self.processed_dir, name)
        self.data_names = os.listdir('{}'.format(self.data_path))

    def __len__(self):
        return len(self.data_names)

    def __getitem__(self, index):
        # each sample is stored as a separate .npy graph; the edge tensors
        # are saved transposed, so restore their layout on access
        data_item = tlx.files.load_npy_to_any(self.data_path, self.data_names[index])
        data_item.edge_attr = data_item.edge_attr.transpose(0, 1)
        data_item.edge_type = data_item.edge_type.transpose(0, 1)
        return data_item

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'ngsim', 'raw', self.split)

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'ngsim', 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        return [f'{self.split.lower()}.zip']

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + '_data.pt'

    def download(self):
        # download the split archive and unpack it into the processed directory
        path = download_url(f'{self.url}/{self.raw_file_names[0]}', self.raw_dir)
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(self.processed_dir)

4 changes: 3 additions & 1 deletion gammagl/layers/conv/__init__.py
@@ -1,3 +1,4 @@
from .heat_conv import HEATlayer
from .message_passing import MessagePassing
from .gcn_conv import GCNConv
from .gat_conv import GATConv
@@ -66,7 +67,8 @@
    'MGNNI_m_iter',
    'MAGCLConv',
    'FusedGATConv',
    'Hid_conv',
    'HEATlayer'
]

classes = __all__