-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
54 lines (36 loc) · 1.55 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torchaudio
import torch
from torch.utils.data import Dataset
class SpeechDataset(Dataset):
    """Paired noisy/clean speech dataset for denoising models.

    Both file lists are sorted so that index ``i`` in one corresponds to
    index ``i`` in the other (assumes matching filenames across the two
    sets — TODO confirm against the data layout). Each item is a
    ``(noisy, clean)`` tuple of mono waveforms, zero-padded on the left
    (or truncated) to a fixed length of ``max_len`` samples.
    """

    def __init__(self, noisy_files, clean_files):
        super().__init__()
        # Sort so noisy/clean pairs line up by filename.
        self.noisy_files = sorted(noisy_files)
        self.clean_files = sorted(clean_files)
        # Number of pairs served by this dataset.
        self.len = len(self.noisy_files)
        # Fixed output length (in samples) of every returned waveform.
        self.max_len = 65280

    def __len__(self):
        return self.len

    def load_sample(self, file):
        """Load an audio file and down-mix to a mono (1, time) tensor."""
        waveform, _ = torchaudio.load(file)
        # Average channels when the file is multi-channel.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        return waveform

    def __getitem__(self, index):
        """Return the (noisy, clean) pair at ``index``, fixed-length."""
        x_clean = self.load_sample(self.clean_files[index])
        x_noisy = self.load_sample(self.noisy_files[index])
        # Pad/cut both to the same fixed length so they can be batched.
        x_clean = self._prepare_sample(x_clean)
        x_noisy = self._prepare_sample(x_noisy)
        return x_noisy, x_clean

    def _prepare_sample(self, waveform):
        """Left-pad with zeros (or truncate) to shape (channels, max_len).

        The waveform is right-aligned in the output buffer: zeros occupy
        the front, the (possibly truncated) signal occupies the tail.
        """
        channels, current_len = waveform.shape
        output = torch.zeros(
            (channels, self.max_len), dtype=torch.float32, device=waveform.device
        )
        n = min(current_len, self.max_len)
        # Guard n == 0: 'output[:, -0:]' would select the WHOLE tensor
        # (since -0 == 0) and raise a shape mismatch when assigning an
        # empty waveform into it.
        if n > 0:
            output[:, -n:] = waveform[:, :n]
        return output