"""
Smoke test for SFTFDataloader: builds the dataloader from an experiment spec,
iterates the full training set, and moves every batch to the GPU while timing
dataloader construction and a complete dataset pass.
"""
import time

import torch
from tqdm import tqdm

from model_loader.stft_dataloader import SFTFDataloader
from model_trainer.trainer_specs import ExperimentSpecs
from model_trainer.utils import to_gpu


def batch_reader(batch, device, version=3):
    """
    Batch parser for DTC: unpacks a collated batch, moves each tensor to the
    target device, and returns an (input, target) pair.

    :param batch: collated batch produced by SFTFDataloader.
    :param device: target torch.device.
    :param version: batch layout version; version 3 includes the STFT tensor.
    :return: tuple (x, y) with the model input and the training target.
    """
    if version == 3:
        text_padded, input_lengths, mel_padded, gate_padded, output_lengths, stft = batch
        text_padded = to_gpu(text_padded, device).long()
        input_lengths = to_gpu(input_lengths, device).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded, device).float()
        gate_padded = to_gpu(gate_padded, device).float()
        output_lengths = to_gpu(output_lengths, device).long()
        # move the STFT tensor to the GPU as well and detach it from autograd
        sf = stft.contiguous()
        if torch.cuda.is_available():
            sf = sf.cuda(non_blocking=True)
        sf.requires_grad = False
        # return the device-resident copy (sf) rather than the original host tensor
        return (text_padded, input_lengths,
                mel_padded, max_len,
                output_lengths, sf), \
               (mel_padded, gate_padded, sf)
    else:
        text_padded, input_lengths, mel_padded, gate_padded, output_lengths = batch
        text_padded = to_gpu(text_padded, device).long()
        input_lengths = to_gpu(input_lengths, device).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded, device).float()
        gate_padded = to_gpu(gate_padded, device).float()
        output_lengths = to_gpu(output_lengths, device).long()
        return (text_padded, input_lengths,
                mel_padded, max_len,
                output_lengths), \
               (mel_padded, gate_padded)
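

# A minimal, hypothetical usage sketch (not part of the original test): builds a
# random version-3 batch with illustrative shapes to show the field order and
# dtypes that batch_reader consumes. All sizes and the function name below are
# assumptions for demonstration only.
def _synthetic_batch_example(device):
    batch_size, max_text_len, n_mel, max_mel_len, n_fft_bins = 2, 40, 80, 200, 513
    text_padded = torch.randint(0, 100, (batch_size, max_text_len))
    input_lengths = torch.tensor([max_text_len, max_text_len // 2])
    mel_padded = torch.randn(batch_size, n_mel, max_mel_len)
    gate_padded = torch.zeros(batch_size, max_mel_len)
    output_lengths = torch.tensor([max_mel_len, max_mel_len // 2])
    stft = torch.randn(batch_size, n_fft_bins, max_mel_len)
    batch = (text_padded, input_lengths, mel_padded, gate_padded, output_lengths, stft)
    # x carries the padded inputs plus max_len; y is the (mel, gate, stft) target
    x, y = batch_reader(batch, device=device, version=3)
    return x, y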
def v3_dataloader_audio_test(config="config.yaml"):
    """
    Full-pass dataloader test: creates SFTFDataloader from an experiment spec,
    then iterates the entire training set once, moving every batch to the GPU.

    :param config: path to the experiment spec yaml file.
    :return: None
    """
    spec = ExperimentSpecs(spec_config=config)

    start_time = time.time()
    dataloader = SFTFDataloader(spec, verbose=True)
    print("--- SFTFDataloader created in %s seconds ---" % (time.time() - start_time))

    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # fetch all splits; only the training set is exercised here
    data_loaders, collate_fn = dataloader.get_all()
    _train_loader = data_loaders['train_set']

    iters = dataloader.get_train_dataset_size() // dataloader.get_batch_size()
    print("Total iterations per epoch:", iters)

    # full pass over the training set, transferring every batch to the GPU
    start_time = time.time()
    for _batch_idx, batch in tqdm(enumerate(_train_loader), total=iters):
        x, y = batch_reader(batch, device=_device, version=3)
    print("--- full dataset pass completed in %s seconds ---" % (time.time() - start_time))


if __name__ == '__main__':
    v3_dataloader_audio_test()