Merge pull request #206 from gjy1221/heat_branch
add model heat
Showing 10 changed files with 553 additions and 1 deletion.
@@ -0,0 +1,153 @@
import argparse
import os

# set thread count and backend before tensorlayerx is imported
os.environ["OMP_NUM_THREADS"] = "4"
os.environ['TL_BACKEND'] = 'torch'

import tensorlayerx as tlx
from gammagl.datasets import NGSIM_US_101
from gammagl.models import HEAT
from gammagl.loader import DataLoader
from tensorlayerx.model import TrainOneStep, WithLoss


class SemiSpvzLoss(WithLoss):
    def __init__(self, net, loss_fun):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fun)

    def forward(self, data, label):
        logits = self._backbone(data.x, data.edge_index, data.edge_attr, data.edge_type)

        # only the target vehicles (tar_mask) contribute to the loss
        train_logits = tlx.gather(logits, data.tar_mask)
        train_y = tlx.gather(data.y, data.tar_mask)

        # RMSE of the predicted future trajectory
        loss = self._loss_fn(train_logits, train_y, reduction='mean')
        loss = tlx.sqrt(loss)

        return loss


def main(args):
    # load the train/val/test splits of NGSIM US-101
    train_set = NGSIM_US_101(root=args.data_path, name='train')
    val_set = NGSIM_US_101(root=args.data_path, name='val')
    test_set = NGSIM_US_101(root=args.data_path, name='test')

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True)

    net = HEAT(args.hist_length, args.in_channels_node, args.out_channels, args.out_length,
               args.in_channels_edge_attr, args.in_channels_edge_type, args.edge_attr_emb_size,
               args.edge_type_emb_size, args.node_emb_size, args.heads, args.concat, args.dropout,
               args.leaky_rate)
    print('loading HEAT model')

    # decay the learning rate at the given epoch milestones; the scheduler is
    # passed to the optimizer so that scheduler.step() actually takes effect
    scheduler = tlx.optimizers.lr.MultiStepDecay(learning_rate=args.lr,
                                                 milestones=[1, 2, 4, 6, 10, 30, 40, 50, 60],
                                                 gamma=0.7, verbose=True)
    optimizer = tlx.optimizers.Adam(lr=scheduler)
    train_weights = net.trainable_weights

    loss_fn = SemiSpvzLoss(net, tlx.losses.mean_squared_error)
    train_one_step = TrainOneStep(loss_fn, optimizer, train_weights)
    best_val_loss = 1000
    for epoch in range(args.n_epoch):
        train_loss_epo = 0.0
        net.set_train()
        for i, data in enumerate(train_loader):
            # keep only the first out_length future frames and flatten them per node
            indices = tlx.arange(0, args.out_length)
            data.y = tlx.gather(data.y, indices, axis=1)
            data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
            loss_each_data = train_one_step(data, data.y)
            train_loss_epo += loss_each_data

        # average over batches and convert feet to meters (1 ft = 0.3048 m)
        train_loss_epoch = round(train_loss_epo * 0.3048 / (i + 1), 4)

        val_loss_epoch = 0
        net.set_eval()
        for j, data in enumerate(val_loader):
            logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
            indices = tlx.arange(0, args.out_length)
            data.y = tlx.gather(data.y, indices, axis=1)
            data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
            val_logits = tlx.gather(logits, data.tar_mask)
            val_y = tlx.gather(data.y, data.tar_mask)
            val_loss_epoch += tlx.convert_to_numpy(
                tlx.sqrt(tlx.losses.mean_squared_error(val_logits, val_y, reduction='mean')))

        val_loss_epoch = round(val_loss_epoch * 0.3048 / (j + 1), 4)

        print("Epoch [{:0>3d}] ".format(epoch + 1) + " train loss: {:.4f}".format(
            train_loss_epoch) + " val loss: {:.4f}".format(val_loss_epoch))

        # save the best model on the validation set
        if val_loss_epoch < best_val_loss:
            best_val_loss = val_loss_epoch
            net.save_weights(os.path.join(args.result_path,
                                          str(args.out_length) + '-' + str(best_val_loss) + '.npz'),
                             format='npz_dict')

        scheduler.step()
    # evaluate the best checkpoint on the test set
    net.set_eval()
    net.load_weights(os.path.join(args.result_path,
                                  str(args.out_length) + '-' + str(best_val_loss) + '.npz'),
                     format='npz_dict')
    total_distance = 0
    total_samples = 0

    for i, data in enumerate(test_loader):
        logits = net(data.x, data.edge_index, data.edge_attr, data.edge_type)
        indices = tlx.arange(0, args.out_length)
        data.y = tlx.gather(data.y, indices, axis=1)
        data.y = tlx.reshape(data.y, (data.y.shape[0], -1))
        test_logits = tlx.gather(logits, data.tar_mask)
        test_y = tlx.gather(data.y, data.tar_mask)

        # Euclidean distance between predicted and ground-truth trajectories
        distance = tlx.sqrt(tlx.reduce_sum(tlx.square(test_logits - test_y)))
        total_distance += distance
        total_samples += len(test_logits)

    # average Euclidean distance per target vehicle
    average_distance = total_distance / total_samples
    print("Average Euclidean distance: {:.4f}".format(average_distance))


if __name__ == '__main__':
    # network arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epoch", type=int, default=40, help="number of epochs")
    parser.add_argument("--in_channels_node", type=int, default=64, help="heat_in_channels_node")
    parser.add_argument("--in_channels_edge_attr", type=int, default=5, help="heat_in_channels_edge_attr")
    parser.add_argument("--in_channels_edge_type", type=int, default=6, help="heat_in_channels_edge_type")
    parser.add_argument("--edge_attr_emb_size", type=int, default=64, help="heat_edge_attr_emb_size")
    parser.add_argument("--edge_type_emb_size", type=int, default=64, help="heat_edge_type_emb_size")
    parser.add_argument("--node_emb_size", type=int, default=64, help="heat_node_emb_size")
    parser.add_argument("--out_channels", type=int, default=128, help="heat_out_channels")

    parser.add_argument("--heads", type=int, default=3, help="number of attention heads")
    parser.add_argument('--concat', type=bool, default=True, help='heat_concat')
    parser.add_argument("--hist_length", type=int, default=10, help="length of history trajectory")
    parser.add_argument("--out_length", type=int, default=30, help="length of future trajectory")
    parser.add_argument("--dropout", type=float, default=0.5, help="dropout rate")
    parser.add_argument("--leaky_rate", type=float, default=0.1, help="LeakyReLU negative slope")

    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=20, help="batch size")
    parser.add_argument("--data_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument("--result_path", type=str, default=r'', help="path to save result")
    parser.add_argument("--device", type=int, default=0)

    args = parser.parse_args()

    if args.device >= 0:
        tlx.set_device("GPU", args.device)
    else:
        tlx.set_device("CPU")

    main(args)
@@ -0,0 +1,30 @@
Heterogeneous Edge-Enhanced Graph Attention Network (HEAT)
============

- Paper link: [https://arxiv.org/abs/2106.07161](https://arxiv.org/abs/2106.07161)
- Author's code repo (in PyTorch):
  [https://github.com/Xiaoyu006/MATP-with-HEAT](https://github.com/Xiaoyu006/MATP-with-HEAT).
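For orientation, a minimal construction sketch using the default hyperparameters from `heat_trainer.py` (arguments are passed positionally, in the same order as in the trainer; these are the argparse defaults, not tuned settings):

```python
from gammagl.models import HEAT

# hist_length=10, in_channels_node=64, out_channels=128, out_length=30,
# in_channels_edge_attr=5, in_channels_edge_type=6, edge_attr_emb_size=64,
# edge_type_emb_size=64, node_emb_size=64, heads=3, concat=True,
# dropout=0.5, leaky_rate=0.1
model = HEAT(10, 64, 128, 30, 5, 6, 64, 64, 64, 3, True, 0.5, 0.1)
```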
Dataset Statistics
-------

| Dataset      | # Graphs | # Node Types | # Edge Types |
|--------------|----------|--------------|--------------|
| NGSIM US-101 | 1201     | 2            | 2            |

Refer to [NGSIM US-101 Datasets](https://github.com/gjy1221/NGSIM-US-101).
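A minimal loading sketch (the root directory and batch size here are illustrative, matching the trainer defaults and the command in the Results section below; the split archive is downloaded and extracted automatically on first use):

```python
from gammagl.datasets import NGSIM_US_101
from gammagl.loader import DataLoader

# available splits: 'train', 'val', 'test'
train_set = NGSIM_US_101(root='../data', name='train')
train_loader = DataLoader(train_set, batch_size=20, shuffle=True)

for data in train_loader:
    # node features, edge connectivity and ground-truth future trajectories
    print(data.x.shape, data.edge_index.shape, data.y.shape)
    break
```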
Results
-------

```bash
TL_BACKEND="torch" python heat_trainer.py --data_path ../data --result_path ../result
```
| Time (sec) | Paper  | Ours (torch) |
|------------|--------|--------------|
| 1          | 0.6067 | 0.6940       |
| 2          | 0.8556 | 0.7349       |
| 3          | 1.0469 | 1.0621       |
| 4          | 1.3216 | 1.2739       |
| 5          | 1.8894 | 1.7562       |
@@ -0,0 +1,79 @@
import os
import os.path as osp
import zipfile
import tensorlayerx as tlx
from gammagl.data import Graph, HeteroGraph, download_url, extract_zip
from gammagl.data import InMemoryDataset
from typing import Callable, Optional, List


class NGSIM_US_101(InMemoryDataset):
    r"""
    The NGSIM US-101 dataset from the `"NGSIM: Next Generation Simulation"
    <https://ops.fhwa.dot.gov/trafficanalysistools/ngsim.htm>`_ project,
    containing detailed vehicle trajectory data from the US-101 highway in
    Los Angeles, California.

    Parameters
    ----------
    root: str
        Root directory where the dataset should be saved.
    name: str, optional
        The name of the dataset split (:obj:`"train"`, :obj:`"val"`, :obj:`"test"`).
    transform: callable, optional
        A function/transform that takes in a
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in a
        :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset. (default: :obj:`False`)
    """

    url = 'https://github.com/gjy1221/NGSIM-US-101/raw/main/data'

    def __init__(self, root: str = None, name: str = None,
                 transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None,
                 force_reload: bool = False):
        self.name = osp.join('ngsim', name.lower())
        self.split = name.lower()
        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data_path = osp.join(self.processed_dir, name)
        self.data_names = os.listdir(self.data_path)

    def __len__(self):
        return len(self.data_names)

    def __getitem__(self, index):
        # each item is stored as a single .npy graph; the edge tensors are
        # saved transposed, so swap their first two dimensions back
        data_item = tlx.files.load_npy_to_any(self.data_path, self.data_names[index])
        data_item.edge_attr = data_item.edge_attr.transpose(0, 1)
        data_item.edge_type = data_item.edge_type.transpose(0, 1)
        return data_item

    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'ngsim', 'raw', self.split)

    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'ngsim', 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        return [f'{self.split.lower()}.zip']

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + '_data.pt'

    def download(self):
        # fetch the split archive and unpack it into the processed directory
        path = download_url(f'{self.url}/{self.raw_file_names[0]}', self.raw_dir)
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(self.processed_dir)
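
For a quick sanity check, a usage sketch of this class (directory layout follows the raw_dir/processed_dir properties above; the fields listed are the ones consumed by heat_trainer.py):

dataset = NGSIM_US_101(root='../data', name='val')
# val.zip is downloaded to ../data/ngsim/raw/val/ and extracted under ../data/ngsim/processed/
print(len(dataset))   # number of stored graphs in the split
graph = dataset[0]    # a graph with x, y, edge_index, edge_attr, edge_type and tar_mask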