# hume_dataset_av.py
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import os
import pickle
import numpy as np
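# Expected on-disk layout, inferred from the path construction in __getitem__ below
# (the concrete folder names are whatever is passed in the `feature` list):
#     <root_dir>/<feature[i]>/00001.pkl   # one pickled [t, d] feature array per sample
# The CSV's first column holds the numeric sample id (zero-padded to five digits to
# form the file name) and the remaining columns hold the labels.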
# Custom dataset class for the audio-visual Hume features
class HumeDataset(Dataset):
    def __init__(self, csv_file, root_dir, feature):
        # csv_file: annotation CSV (first column: sample id, remaining columns: labels)
        # root_dir: dataset root containing one sub-directory per feature type
        # feature:  list of feature sub-directory names,
        #           indexed in __getitem__ as [video CNN, AUs, audio]
        self.csv_data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.feature = feature

    def __len__(self):
        return len(self.csv_data)
    def __getitem__(self, idx):
        file_name = str(self.csv_data.iloc[idx, 0]).zfill(5) + '.pkl'

        # Video features: ResNet-18 frame features concatenated with AU features
        fea_resnet18 = self.feature[0]
        file_path = os.path.join(self.root_dir, fea_resnet18, file_name)
        with open(file_path, 'rb') as f:
            data_resnet18 = pickle.load(f).detach().numpy()  # shape [t, d]
        fea_aus = self.feature[1]
        file_path = os.path.join(self.root_dir, fea_aus, file_name)
        with open(file_path, 'rb') as f:
            data_aus = pickle.load(f).astype(np.float32)
        data_v = np.concatenate((data_resnet18, data_aus), axis=1)

        # Align along the time dimension to a fixed length t
        t = 300
        if data_v.shape[0] < t:
            # Shorter sequences: pad by wrapping around to the start
            data_v = np.pad(data_v, ((0, t - data_v.shape[0]), (0, 0)), mode='wrap')
        else:
            # Longer sequences: keep t evenly spaced frames
            indices = np.linspace(0, data_v.shape[0] - 1, num=t, dtype=int)
            data_v = data_v[indices]
        vid_fea = torch.tensor(data_v, dtype=torch.float32)
        # Audio features
        audio_feature = self.feature[2]
        audfile_path = os.path.join(self.root_dir, audio_feature, file_name)
        with open(audfile_path, 'rb') as f:
            aud_data = pickle.load(f)  # shape [t, d]

        # Align along the time dimension to a fixed length a_t
        a_t = 300
        if aud_data.shape[0] < a_t:
            # Shorter sequences: pad by wrapping around to the start
            aud_data = np.pad(aud_data, ((0, a_t - aud_data.shape[0]), (0, 0)), mode='wrap')
        else:
            # Longer sequences: keep a_t evenly spaced frames
            indices = np.linspace(0, aud_data.shape[0] - 1, num=a_t, dtype=int)
            aud_data = aud_data[indices]
        aud_fea = torch.tensor(aud_data, dtype=torch.float32)

        # Labels: every column after the file id, converted to a float tensor
        label = self.csv_data.iloc[idx, 1:].to_numpy()
        label_tensor = torch.tensor(label, dtype=torch.float32)
        return vid_fea, aud_fea, label_tensor
    @staticmethod
    def collate_fn(batch):
        # For reference, PyTorch's default_collate implementation:
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        vid_fea, aud_fea, labels = tuple(zip(*batch))
        vid_fea = torch.stack(vid_fea, dim=0)
        aud_fea = torch.stack(aud_fea, dim=0)
        # Stack the per-sample label tensors into a [batch, num_labels] tensor
        labels = torch.stack(labels, dim=0)
        return vid_fea, aud_fea, labels
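
# A minimal sketch (not part of the training pipeline) of the temporal alignment used
# in __getitem__ above: sequences shorter than the target length are wrap-padded,
# longer ones are reduced to evenly spaced frames. The array shapes below are made up
# purely for illustration.
def _demo_temporal_alignment(target_len=300):
    short_seq = np.random.rand(120, 8).astype(np.float32)  # shorter than target_len
    long_seq = np.random.rand(750, 8).astype(np.float32)   # longer than target_len

    # Wrap-padding repeats the sequence from its start until target_len is reached
    padded = np.pad(short_seq, ((0, target_len - short_seq.shape[0]), (0, 0)), mode='wrap')

    # Evenly spaced indices keep target_len frames spanning the whole sequence
    indices = np.linspace(0, long_seq.shape[0] - 1, num=target_len, dtype=int)
    subsampled = long_seq[indices]

    print(padded.shape, subsampled.shape)  # both (target_len, 8)
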
if __name__ == "__main__":
    from tqdm import tqdm

    root = 'D:/Desktop/code/6th-ABAW/dataset/'
    # __getitem__ indexes three feature directories: [video CNN, AUs, audio].
    # The names below are placeholders; replace them with the actual feature folders.
    feature = ['vit', 'aus', 'wav2vec2']

    # Instantiate the training and validation datasets
    train_dataset = HumeDataset(os.path.join(root, 'train_split.csv'), feature=feature, root_dir=root)
    val_dataset = HumeDataset(os.path.join(root, 'valid_split.csv'), feature=feature, root_dir=root)
    # for vid_fea, aud_fea, labels in tqdm(train_dataset, desc='train'):
    #     print(vid_fea.shape, aud_fea.shape, labels.shape)
    # print(len(train_dataset))
    # print(len(val_dataset))

    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, pin_memory=True, num_workers=8)

    # Iterate over the validation loader as a smoke test
    for vid_fea, aud_fea, labels in tqdm(val_loader, desc='valdataloader'):
        # training / validation would go here
        pass

    # first_batch = next(iter(train_loader))
    # print(first_batch[1].shape)
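
    # Sketch: the loaders above rely on PyTorch's default_collate; to use the custom
    # collate_fn defined on HumeDataset instead, pass it to the DataLoader explicitly.
    custom_val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False,
                                   collate_fn=HumeDataset.collate_fn)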