diff --git a/.gitignore b/.gitignore
index 77f8765..722decd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,3 +16,4 @@
 *~
 .vscode/
+venv/
diff --git a/README.md b/README.md
index 1e6d195..0b65277 100644
--- a/README.md
+++ b/README.md
@@ -38,11 +38,11 @@
 python inference.py --input path/to/an/audio/file --gpu 0
 python inference.py --input path/to/an/audio/file --tta --gpu 0
 ```
-
+```
 
 ## Train your own model
@@ -61,12 +61,13 @@
 path/to/dataset/
 
 ### Train a model
 ```
-python train.py --dataset path/to/dataset --mixup_rate 0.5 --gpu 0
+python train.py --dataset path/to/dataset --mixup_rate 0.5 --reduction_rate 0.5 --gpu 0
 ```
 ## References
 - [1] Jansson et al., "Singing Voice Separation with Deep U-Net Convolutional Networks", https://ejhumphrey.com/assets/pdf/jansson2017singing.pdf
 - [2] Takahashi et al., "Multi-scale Multi-band DenseNets for Audio Source Separation", https://arxiv.org/pdf/1706.09588.pdf
 - [3] Takahashi et al., "MMDENSELSTM: AN EFFICIENT COMBINATION OF CONVOLUTIONAL AND RECURRENT NEURAL NETWORKS FOR AUDIO SOURCE SEPARATION", https://arxiv.org/pdf/1805.02410.pdf
-- [4] Jansson et al., "Learned complex masks for multi-instrument source separation", https://arxiv.org/pdf/2103.12864.pdf
-- [5] Liutkus et al., "The 2016 Signal Separation Evaluation Campaign", Latent Variable Analysis and Signal Separation - 12th International Conference
+- [4] Choi et al., "PHASE-AWARE SPEECH ENHANCEMENT WITH DEEP COMPLEX U-NET", https://openreview.net/pdf?id=SkeRTsAcYm
+- [5] Jansson et al., "Learned complex masks for multi-instrument source separation", https://arxiv.org/pdf/2103.12864.pdf
+- [6] Liutkus et al., "The 2016 Signal Separation Evaluation Campaign", Latent Variable Analysis and Signal Separation - 12th International Conference
diff --git a/inference.py b/inference.py
index 8a5f142..f583841 100644
--- a/inference.py
+++ b/inference.py
@@ -15,7 +15,7 @@
 
 class Separator(object):
 
-    def __init__(self, model, device, batchsize, cropsize, postprocess=False):
+    def __init__(self, model, device=None, batchsize=1, cropsize=256, postprocess=False):
         self.model = model
         self.offset = model.offset
         self.device = device
@@ -27,10 +27,15 @@
     def _postprocess(self, X_spec, mask):
         if self.postprocess:
             mask_mag = np.abs(mask)
             mask_mag = spec_utils.merge_artifacts(mask_mag)
-            mask = mask_mag + 1.j * np.exp(np.angle(mask))
+            mask = mask_mag * np.exp(1.j * np.angle(mask))
 
-        y_spec = X_spec * mask
-        v_spec = X_spec - y_spec
+        X_mag = np.abs(X_spec)
+        X_phase = np.angle(X_spec)
+
+        y_spec = mask * X_mag * np.exp(1.j * X_phase)
+        v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
+        # y_spec = X_spec * mask
+        # v_spec = X_spec - y_spec
 
         return y_spec, v_spec
@@ -52,7 +57,7 @@ def _separate(self, X_spec_pad, roi_size):
             X_batch = X_dataset[i: i + self.batchsize]
             X_batch = torch.from_numpy(X_batch).to(self.device)
 
-            mask = self.model.predict_mask(X_batch)
+            mask = self.model.predict_mask(torch.abs(X_batch))
 
             mask = mask.detach().cpu().numpy()
             mask = np.concatenate(mask, axis=2)
@@ -109,26 +114,26 @@ def main():
     p.add_argument('--cropsize', '-c', type=int, default=256)
     p.add_argument('--output_image', '-I', action='store_true')
     p.add_argument('--tta', '-t', action='store_true')
-    # p.add_argument('--postprocess', '-p', action='store_true')
+    p.add_argument('--postprocess', '-p', action='store_true')
     p.add_argument('--output_dir', '-o', type=str, default="")
     args = p.parse_args()
 
     print('loading model...', end=' ')
    device = torch.device('cpu')
-    model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128)
-    model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
     if args.gpu >= 0:
         if torch.cuda.is_available():
             device = torch.device('cuda:{}'.format(args.gpu))
-            model.to(device)
         elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
             device = torch.device('mps')
-            model.to(device)
+    model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128)
+    model.load_state_dict(torch.load(args.pretrained_model, map_location='cpu'))
+    model.to(device)
     print('done')
 
     print('loading wave source...', end=' ')
     X, sr = librosa.load(
-        args.input, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
+        args.input, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast'
+    )
     basename = os.path.splitext(os.path.basename(args.input))[0]
     print('done')
@@ -140,7 +145,13 @@ def main():
     X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
     print('done')
 
-    sp = Separator(model, device, args.batchsize, args.cropsize)
+    sp = Separator(
+        model=model,
+        device=device,
+        batchsize=args.batchsize,
+        cropsize=args.cropsize,
+        postprocess=args.postprocess
+    )
 
     if args.tta:
         y_spec, v_spec = sp.separate_tta(X_spec)
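Note on the new `_postprocess` arithmetic: with the magnitude-mask model the mask is real-valued, so both estimates reuse the mixture phase and still sum back to the mixture. A minimal NumPy sketch (toy shapes and random values, not the repository's data) of the reconstruction:

```python
import numpy as np

rng = np.random.default_rng(0)
X_spec = rng.normal(size=(2, 5, 8)) + 1j * rng.normal(size=(2, 5, 8))  # toy mixture spectrogram
mask = rng.uniform(size=(2, 5, 8))                                     # real mask in [0, 1]

X_mag = np.abs(X_spec)
X_phase = np.angle(X_spec)

# Same reconstruction as the new _postprocess: split magnitudes, share phase.
y_spec = mask * X_mag * np.exp(1.j * X_phase)        # instruments
v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)  # vocals

# For a real-valued mask the two estimates still sum to the mixture,
# matching the commented-out v_spec = X_spec - y_spec formulation.
assert np.allclose(y_spec + v_spec, X_spec)
```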
diff --git a/lib/dataset.py b/lib/dataset.py
index 1c070b9..4cdf8a8 100644
--- a/lib/dataset.py
+++ b/lib/dataset.py
@@ -46,6 +46,16 @@ def read_npy_chunk(self, path, start_row):
 
         return flat.reshape((-1,) + shape[1:])
 
+    def aggressively_remove_vocal(self, X, y):
+        X_mag = np.abs(X)
+        y_mag = np.abs(y)
+        v_mag = X_mag - y_mag
+        v_mag *= v_mag > y_mag
+
+        y_mag = np.clip(y_mag - v_mag * self.reduction_weight, 0, np.inf)
+
+        return y_mag * np.exp(1.j * np.angle(y))
+
     def do_crop(self, X_path, y_path):
         shape = self.read_npy_shape(X_path)
         start_row = np.random.randint(0, shape[0] - self.cropsize)
@@ -57,7 +67,7 @@
 
     def do_aug(self, X, y):
         if np.random.uniform() < self.reduction_rate:
-            y = spec_utils.aggressively_remove_vocal(X, y, self.reduction_weight)
+            y = self.aggressively_remove_vocal(X, y)
 
         if np.random.uniform() < 0.5:
             # swap channel
@@ -103,7 +113,11 @@ def __getitem__(self, idx):
         if np.random.uniform() < self.mixup_rate:
             X, y = self.do_mixup(X, y)
 
-        return X, y
+        X_mag = np.abs(X)
+        y_mag = np.abs(y)
+
+        return X_mag, y_mag
+        # return X, y
 
 
 class VocalRemoverValidationSet(torch.utils.data.Dataset):
@@ -120,7 +134,11 @@ def __getitem__(self, idx):
 
         X, y = data['X'], data['y']
 
-        return X, y
+        X_mag = np.abs(X)
+        y_mag = np.abs(y)
+
+        return X_mag, y_mag
+        # return X, y
 
 
 def make_pair(mix_dir, inst_dir):
diff --git a/lib/layers.py b/lib/layers.py
index 7bc0b7c..e184123 100644
--- a/lib/layers.py
+++ b/lib/layers.py
@@ -26,33 +26,6 @@ def __call__(self, x):
         return self.conv(x)
 
 
-# class SeperableConv2DBNActiv(nn.Module):
-
-#     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
-#         super(SeperableConv2DBNActiv, self).__init__()
-#         self.conv = nn.Sequential(
-#             nn.Conv2d(
-#                 nin, nin,
-#                 kernel_size=ksize,
-#                 stride=stride,
-#                 padding=pad,
-#                 dilation=dilation,
-#                 groups=nin,
-#                 bias=False
-#             ),
-#             nn.Conv2d(
-#                 nin, nout,
-#                 kernel_size=1,
-#                 bias=False
-#             ),
-#             nn.BatchNorm2d(nout),
-#             activ()
-#         )
-
-#     def __call__(self, x):
-#         return self.conv(x)
-
-
 class Encoder(nn.Module):
 
     def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
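For reference, a tiny worked example of the boolean-mask trick inside the relocated `aggressively_remove_vocal` above (three hypothetical bins): multiplying by `v_mag > y_mag` zeroes the bins where the crude vocal estimate does not dominate, so only vocal-heavy bins get subtracted from the target.

```python
import numpy as np

X_mag = np.array([1.0, 0.9, 0.2])  # mixture magnitudes (three hypothetical bins)
y_mag = np.array([0.2, 0.8, 0.1])  # instrumental target magnitudes

v_mag = X_mag - y_mag   # crude vocal estimate: [0.8, 0.1, 0.1]
v_mag *= v_mag > y_mag  # keep vocal-dominant bins only: [0.8, 0.0, 0.0]

reduction_weight = 0.5  # scalar stand-in for the per-bin ramp built in train.py
y_reduced = np.clip(y_mag - v_mag * reduction_weight, 0, np.inf)
print(y_reduced)  # [0.0, 0.8, 0.1]: only the first bin is attenuated (and clipped at 0)
```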
diff --git a/lib/nets.py b/lib/nets.py
index cb02e2f..351fba7 100644
--- a/lib/nets.py
+++ b/lib/nets.py
@@ -43,40 +43,46 @@ def __call__(self, x):
 
 class CascadedNet(nn.Module):
 
-    def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128):
+    def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False):
         super(CascadedNet, self).__init__()
         self.n_fft = n_fft
         self.hop_length = hop_length
+        self.is_complex = is_complex
+
         self.max_bin = n_fft // 2
         self.output_bin = n_fft // 2 + 1
         self.nin_lstm = self.max_bin // 2
         self.offset = 64
 
+        nin = 4 if is_complex else 2
+
         self.stg1_low_band_net = nn.Sequential(
-            BaseNet(4, nout // 2, self.nin_lstm // 2, nout_lstm),
+            BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm),
             layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
         )
         self.stg1_high_band_net = BaseNet(
-            4, nout // 4, self.nin_lstm // 2, nout_lstm // 2
+            nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2
         )
 
         self.stg2_low_band_net = nn.Sequential(
-            BaseNet(nout // 4 + 4, nout, self.nin_lstm // 2, nout_lstm),
+            BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm),
             layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
         )
         self.stg2_high_band_net = BaseNet(
-            nout // 4 + 4, nout // 2, self.nin_lstm // 2, nout_lstm // 2
+            nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2
         )
 
         self.stg3_full_band_net = BaseNet(
-            3 * nout // 4 + 4, nout, self.nin_lstm, nout_lstm
+            3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm
         )
 
-        self.out = nn.Conv2d(nout, 4, 1, bias=False)
-        self.aux_out = nn.Conv2d(3 * nout // 4, 4, 1, bias=False)
+        self.out = nn.Conv2d(nout, nin, 1, bias=False)
+        self.aux_out = nn.Conv2d(3 * nout // 4, nin, 1, bias=False)
 
     def forward(self, x):
-        x = torch.cat([x.real, x.imag], dim=1)
+        if self.is_complex:
+            x = torch.cat([x.real, x.imag], dim=1)
+
         x = x[:, :, :self.max_bin]
 
         bandw = x.size()[2] // 2
@@ -95,26 +101,25 @@ def forward(self, x):
         f3_in = torch.cat([x, aux1, aux2], dim=1)
         f3 = self.stg3_full_band_net(f3_in)
 
-        mask = torch.tanh(self.out(f3))
-        mask = torch.complex(mask[:, :2], mask[:, 2:])
+        if self.is_complex:
+            mask = self.out(f3)
+            mask = torch.complex(mask[:, :2], mask[:, 2:])
+            mask = self.bounded_mask(mask)
+        else:
+            mask = torch.sigmoid(self.out(f3))
+
         mask = F.pad(
             input=mask,
             pad=(0, 0, 0, self.output_bin - mask.size()[2]),
             mode='replicate'
         )
 
-        if self.training:
-            aux = torch.cat([aux1, aux2], dim=1)
-            aux = torch.tanh(self.aux_out(aux))
-            aux = torch.complex(aux[:, :2], aux[:, 2:])
-            aux = F.pad(
-                input=aux,
-                pad=(0, 0, 0, self.output_bin - aux.size()[2]),
-                mode='replicate'
-            )
-            return mask, aux
-        else:
-            return mask
+        return mask
+
+    def bounded_mask(self, mask, eps=1e-8):
+        mask_mag = torch.abs(mask)
+        mask = torch.tanh(mask_mag) * mask / (mask_mag + eps)
+        return mask
 
     def predict_mask(self, x):
         mask = self.forward(x)
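The `bounded_mask` helper squashes the complex mask's magnitude through `tanh` while leaving its phase untouched, since `mask / |mask|` is a unit-magnitude complex number (cf. the phase-aware masking in reference [4]). A quick self-contained check of both properties:

```python
import torch

def bounded_mask(mask, eps=1e-8):
    # same computation as CascadedNet.bounded_mask in lib/nets.py
    mask_mag = torch.abs(mask)
    return torch.tanh(mask_mag) * mask / (mask_mag + eps)

mask = torch.complex(torch.randn(16), torch.randn(16)) * 3.0  # unbounded complex mask
bounded = bounded_mask(mask)

# magnitude is tanh-bounded to (0, 1); phase is preserved
assert torch.allclose(torch.abs(bounded), torch.tanh(torch.abs(mask)), atol=1e-4)
assert torch.allclose(torch.angle(bounded), torch.angle(mask), atol=1e-4)
```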
diff --git a/lib/spec_utils.py b/lib/spec_utils.py
index 7b74982..5822420 100644
--- a/lib/spec_utils.py
+++ b/lib/spec_utils.py
@@ -57,18 +57,6 @@ def spectrogram_to_image(spec, mode='magnitude'):
     return img
 
 
-def aggressively_remove_vocal(X, y, weight):
-    X_mag = np.abs(X)
-    y_mag = np.abs(y)
-    # v_mag = np.abs(X_mag - y_mag)
-    v_mag = X_mag - y_mag
-    v_mag *= v_mag > y_mag
-
-    y_mag = np.clip(y_mag - v_mag * weight, 0, np.inf)
-
-    return y_mag * np.exp(1.j * np.angle(y))
-
-
 def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
     if min_range < fade_size * 2:
         raise ValueError('min_range must be >= fade_size * 2')
@@ -182,13 +170,19 @@ def spectrogram_to_wave(spec, hop_length=1024):
     import sys
 
     X, _ = librosa.load(
-        sys.argv[1], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast')
+        sys.argv[1], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast'
+    )
     y, _ = librosa.load(
-        sys.argv[2], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast')
+        sys.argv[2], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast'
+    )
 
     X, y = align_wave_head_and_tail(X, y, 44100)
 
     X_spec = wave_to_spectrogram(X, 1024, 2048)
     y_spec = wave_to_spectrogram(y, 1024, 2048)
+
+    # X_spec = np.load(sys.argv[1]).transpose(1, 2, 0)
+    # y_spec = np.load(sys.argv[2]).transpose(1, 2, 0)
+
     v_spec = X_spec - y_spec
 
     X_image = spectrogram_to_image(X_spec)
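In the `__main__` demo the vocal spectrogram is still derived as `X_spec - y_spec`, which is valid because the STFT is linear. A quick check of that identity with plain `librosa.stft` and toy signals (the repository's `wave_to_spectrogram` is built on the same transform):

```python
import numpy as np
import librosa

rng = np.random.default_rng(0)
x = rng.normal(size=44100).astype(np.float32)  # toy "mixture"
y = 0.5 * x                                    # toy "instrumental"

X_spec = librosa.stft(x, n_fft=2048, hop_length=1024)
y_spec = librosa.stft(y, n_fft=2048, hop_length=1024)
v_spec = librosa.stft(x - y, n_fft=2048, hop_length=1024)

# stft(x - y) == stft(x) - stft(y), so mixture minus instrumental
# really is the vocal spectrogram (given aligned waves).
assert np.allclose(X_spec - y_spec, v_spec, atol=1e-4)
```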
diff --git a/train.py b/train.py
index b2af717..7735203 100644
--- a/train.py
+++ b/train.py
@@ -22,8 +22,7 @@ def setup_logger(name, logfile='LOGFILENAME.log'):
 
     fh = logging.FileHandler(logfile, encoding='utf8')
     fh.setLevel(logging.DEBUG)
-    fh_formatter = logging.Formatter(
-        '%(asctime)s - %(levelname)s - %(message)s')
+    fh_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
     fh.setFormatter(fh_formatter)
 
     sh = logging.StreamHandler()
@@ -35,48 +34,60 @@ def setup_logger(name, logfile='LOGFILENAME.log'):
     return logger
 
 
+def to_wave(spec, n_fft, hop_length, window):
+    B, _, N, T = spec.shape
+    wave = spec.reshape(-1, N, T)
+    wave = torch.istft(wave, n_fft, hop_length, window=window)
+    wave = wave.reshape(B, 2, -1)
+
+    return wave
+
+
+def sdr_loss(y, y_pred, eps=1e-8):
+    sdr = (y * y_pred).sum()
+    sdr /= torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps
+
+    return -sdr
+
+
+def weighted_sdr_loss(y, y_pred, n, n_pred, eps=1e-8):
+    y_sdr = (y * y_pred).sum()
+    y_sdr /= torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps
+
+    noise_sdr = (n * n_pred).sum()
+    noise_sdr /= torch.linalg.norm(n) * torch.linalg.norm(n_pred) + eps
+
+    a = torch.sum(y ** 2)
+    a /= torch.sum(y ** 2) + torch.sum(n ** 2) + eps
+
+    loss = a * y_sdr + (1 - a) * noise_sdr
+
+    return -loss
+
+
 def train_epoch(dataloader, model, device, optimizer, accumulation_steps):
     model.train()
-    n_fft = model.n_fft
-    hop_length = model.hop_length
-    window = torch.hann_window(n_fft).to(device)
+    # n_fft = model.n_fft
+    # hop_length = model.hop_length
+    # window = torch.hann_window(n_fft).to(device)
 
     sum_loss = 0
-    crit = nn.L1Loss()
+    crit_l1 = nn.L1Loss()
 
     for itr, (X_batch, y_batch) in enumerate(dataloader):
         X_batch = X_batch.to(device)
         y_batch = y_batch.to(device)
 
-        mask, aux = model(X_batch)
-
-        pred = X_batch * mask
-        aux = X_batch * aux
-
-        y_mag_batch = torch.abs(y_batch)
-
-        B, C, N, T = X_batch.shape
-        y_wave_batch = y_batch.reshape(-1, N, T)
-        y_wave_batch = torch.istft(y_wave_batch, n_fft, hop_length, window=window)
-        y_wave_batch = y_wave_batch.reshape(B, 2, -1)
+        mask = model(X_batch)
 
-        pred_wave = pred.reshape(-1, N, T)
-        pred_wave = torch.istft(pred_wave, n_fft, hop_length, window=window).reshape(B, 2, -1)
-        pred_wave = pred_wave.reshape(B, 2, -1)
-        pred_sdr_inner = (pred_wave * y_wave_batch).sum()
-        pred_sdr_norm = torch.linalg.norm(pred_wave) * torch.linalg.norm(y_wave_batch)
+        # y_pred = X_batch * mask
+        # y_wave_batch = to_wave(y_batch, n_fft, hop_length, window)
+        # y_wave_pred = to_wave(y_pred, n_fft, hop_length, window)
 
-        loss_main = crit(torch.abs(pred), y_mag_batch) - (pred_sdr_inner / pred_sdr_norm) * 1e-2
+        # loss = crit_l1(torch.abs(y_batch), torch.abs(y_pred))
+        # loss += sdr_loss(y_wave_batch, y_wave_pred) * 0.01
+        loss = crit_l1(mask * X_batch, y_batch)
 
-        aux_wave = aux.reshape(-1, N, T)
-        aux_wave = torch.istft(aux_wave, n_fft, hop_length, window=window).reshape(B, 2, -1)
-        aux_wave = aux_wave.reshape(B, 2, -1)
-        aux_sdr_inner = (aux_wave * y_wave_batch).sum()
-        aux_sdr_norm = torch.linalg.norm(aux_wave) * torch.linalg.norm(y_wave_batch)
-
-        loss_aux = crit(torch.abs(aux), y_mag_batch) - (aux_sdr_inner / aux_sdr_norm) * 1e-2
-
-        loss = loss_main * 0.8 + loss_aux * 0.2
         accum_loss = loss / accumulation_steps
         accum_loss.backward()
@@ -96,35 +107,27 @@
 def validate_epoch(dataloader, model, device):
     model.eval()
-    n_fft = model.n_fft
-    hop_length = model.hop_length
-    window = torch.hann_window(n_fft).to(device)
+    # n_fft = model.n_fft
+    # hop_length = model.hop_length
+    # window = torch.hann_window(n_fft).to(device)
 
     sum_loss = 0
-    crit = nn.L1Loss()
+    crit_l1 = nn.L1Loss()
 
     with torch.no_grad():
         for X_batch, y_batch in dataloader:
             X_batch = X_batch.to(device)
             y_batch = y_batch.to(device)
 
-            pred = model.predict(X_batch)
-
-            y_batch = spec_utils.crop_center(y_batch, pred)
-            y_mag_batch = torch.abs(y_batch)
+            y_pred = model.predict(X_batch)
 
-            B, C, N, T = X_batch.shape
-            y_wave_batch = y_batch.reshape(-1, N, T)
-            y_wave_batch = torch.istft(y_wave_batch, n_fft, hop_length, window=window)
-            y_wave_batch = y_wave_batch.reshape(B, 2, -1)
+            y_batch = spec_utils.crop_center(y_batch, y_pred)
+            # y_wave_batch = to_wave(y_batch, n_fft, hop_length, window)
+            # y_wave_pred = to_wave(y_pred, n_fft, hop_length, window)
 
-            pred_wave = pred.reshape(-1, N, T)
-            pred_wave = torch.istft(pred_wave, n_fft, hop_length, window=window).reshape(B, 2, -1)
-            pred_wave = pred_wave.reshape(B, 2, -1)
-            pred_sdr_inner = (pred_wave * y_wave_batch).sum()
-            pred_sdr_norm = torch.linalg.norm(pred_wave) * torch.linalg.norm(y_wave_batch)
-
-            loss = crit(torch.abs(pred), y_mag_batch) - (pred_sdr_inner / pred_sdr_norm) * 1e-2
+            # loss = crit_l1(torch.abs(y_batch), torch.abs(y_pred))
+            # loss += sdr_loss(y_wave_batch, y_wave_pred) * 0.01
+            loss = crit_l1(y_pred, y_batch)
 
             sum_loss += loss.item() * len(X_batch)
@@ -191,6 +194,16 @@ def main():
     for i, (X_fname, y_fname) in enumerate(val_filelist):
         logger.info('{} {} {}'.format(i + 1, os.path.basename(X_fname), os.path.basename(y_fname)))
 
+    bins = args.n_fft // 2 + 1
+    freq_to_bin = 2 * bins / args.sr
+    unstable_bins = int(200 * freq_to_bin)
+    stable_bins = int(22050 * freq_to_bin)
+    reduction_weight = np.concatenate([
+        np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
+        np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
+        np.zeros((bins - stable_bins, 1), dtype=np.float32),
+    ], axis=0) * args.reduction_level
+
     device = torch.device('cpu')
     model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128)
     if args.pretrained_model is not None:
@@ -213,16 +226,6 @@
         verbose=True
     )
 
-    bins = args.n_fft // 2 + 1
-    freq_to_bin = 2 * bins / args.sr
-    unstable_bins = int(200 * freq_to_bin)
-    stable_bins = int(22050 * freq_to_bin)
-    reduction_weight = np.concatenate([
-        np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
-        np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
-        np.zeros((bins - stable_bins, 1), dtype=np.float32),
-    ], axis=0) * args.reduction_level
-
     training_set = dataset.make_training_set(
         filelist=train_filelist,
         sr=args.sr,
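The newly added (and, per the commented loss lines above, currently unused) `sdr_loss` is a negative cosine similarity between flattened waveforms: it is invariant to positive rescaling of the prediction and bottoms out at -1 for a perfectly aligned estimate. A small sanity check with toy waveforms, the function copied verbatim from the diff:

```python
import torch

def sdr_loss(y, y_pred, eps=1e-8):
    sdr = (y * y_pred).sum()
    sdr /= torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps
    return -sdr

y = torch.randn(2, 2, 44100)  # (batch, channels, samples), toy waveforms

assert torch.isclose(sdr_loss(y, 3.0 * y), torch.tensor(-1.0), atol=1e-4)  # scale-invariant optimum
assert torch.isclose(sdr_loss(y, -y), torch.tensor(1.0), atol=1e-4)        # worst case: phase-inverted
```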