-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
78 lines (61 loc) · 2.45 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import math
import numpy as np
import pandas as pd
import sys
sys.path.append('./lib/')
from pkl_process import *
from utils import load_graphdata_channel_my, compute_val_loss_sttn, image_to_patches
import tools
from time import time
import shutil
import argparse
import configparser
from tensorboardX import SummaryWriter
import os
from ST_Transformer_new import STTransformer # STTN model with linear layer to get positional embedding
from ST_Transformer_new_sinembedding import STTransformer_sinembedding  # STTN model with sin()/cos() positional embedding, the same as in "Attention Is All You Need"
from VQ_VAE import VQVAE
#%%
if __name__ == '__main__':
    # Inference/debug entry point: restore a saved VQVAE checkpoint, pull
    # batches from a recorded episode dataset, run them through the network
    # and stop in the debugger after every batch so the output can be
    # inspected interactively.

    # Directory holding the saved network parameters, and the checkpoint
    # file inside it that should be restored.
    params_path = './Experiment/debug/'
    print('params_path:', params_path)
    param_file = 'epoch_248000.params'

    # Use the first GPU when one is available, otherwise fall back to CPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    batch_size = 36
    batch_length = 18

    ### Generate Data Loader
    # NOTE(review): machine-specific absolute path; adjust for your setup.
    dataset_path = '/media/ytzheng/3EA48EC9A48E835F/CARLA_DATA/Town01_copycat'
    eval_episodes = tools.load_episodes(dataset_path)
    # Presumably yields sequences of `batch_length` steps sampled from the
    # loaded episodes -- confirm against tools.sample_episodes.
    generator = tools.sample_episodes(eval_episodes, batch_length)
    # Consumed with next() below, so this must be an iterator of batches.
    eval_dataset = tools.from_generator(generator, batch_size)

    ### Construct Network
    net = VQVAE()
    net.to(device)
    print(net)

    param_path = os.path.join(params_path, param_file)
    print("Loading params from: ", param_path)
    # map_location remaps the stored tensors onto the device selected above,
    # so a checkpoint written on a GPU still loads on a CPU-only machine.
    net.load_state_dict(torch.load(param_path, map_location=device))

    ### Inference
    # Switching to evaluation mode is loop-invariant, so do it once up front.
    net.eval()
    num_episodes = 1000
    for episode in range(num_episodes):
        eval_batch = next(eval_dataset)
        # No gradients are needed for inference; skipping autograd tracking
        # avoids building and retaining a graph for every batch.
        with torch.no_grad():
            patch_images = image_to_patches(
                torch.Tensor(eval_batch['image']).to(device)
            )
            # Swap dimensions 1 and 2 of the 4-D patch tensor before the
            # forward pass. The network returns a pair; only the first
            # element is kept (presumably the quantized output -- confirm
            # against VQVAE.forward).
            a_q, _ = net(patch_images.permute(0, 2, 1, 3))
        # Deliberate breakpoint: this script exists to inspect `a_q` by hand
        # after each batch. Continue in the debugger to process the next one.
        import ipdb
        ipdb.set_trace()