Skip to content

Commit

Permalink
Merge pull request #163 from tsurumeso/feature/revert_v6
Browse files Browse the repository at this point in the history
Revert v6 from develop
  • Loading branch information
tsurumeso authored Nov 30, 2023
2 parents 1f69e49 + b750f37 commit 8a02fc5
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 145 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@

*~
.vscode/
venv/
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ python inference.py --input path/to/an/audio/file --gpu 0
python inference.py --input path/to/an/audio/file --tta --gpu 0
```

<!-- `--postprocess` option masks instrumental part based on the vocals volume to improve the separation quality.
`--postprocess` option masks instrumental part based on the vocals volume to improve the separation quality.
**Experimental Warning**: If you get any problems with this option, please disable it.
```
python inference.py --input path/to/an/audio/file --postprocess --gpu 0
``` -->
```

## Train your own model

Expand All @@ -61,12 +61,13 @@ path/to/dataset/

### Train a model
```
python train.py --dataset path/to/dataset --mixup_rate 0.5 --gpu 0
python train.py --dataset path/to/dataset --mixup_rate 0.5 --reduction_rate 0.5 --gpu 0
```

## References
- [1] Jansson et al., "Singing Voice Separation with Deep U-Net Convolutional Networks", https://ejhumphrey.com/assets/pdf/jansson2017singing.pdf
- [2] Takahashi et al., "Multi-scale Multi-band DenseNets for Audio Source Separation", https://arxiv.org/pdf/1706.09588.pdf
- [3] Takahashi et al., "MMDENSELSTM: AN EFFICIENT COMBINATION OF CONVOLUTIONAL AND RECURRENT NEURAL NETWORKS FOR AUDIO SOURCE SEPARATION", https://arxiv.org/pdf/1805.02410.pdf
- [4] Jansson et al., "Learned complex masks for multi-instrument source separation", https://arxiv.org/pdf/2103.12864.pdf
- [5] Liutkus et al., "The 2016 Signal Separation Evaluation Campaign", Latent Variable Analysis and Signal Separation - 12th International Conference
- [4] Choi et al., "PHASE-AWARE SPEECH ENHANCEMENT WITH DEEP COMPLEX U-NET", https://openreview.net/pdf?id=SkeRTsAcYm
- [5] Jansson et al., "Learned complex masks for multi-instrument source separation", https://arxiv.org/pdf/2103.12864.pdf
- [6] Liutkus et al., "The 2016 Signal Separation Evaluation Campaign", Latent Variable Analysis and Signal Separation - 12th International Conference
35 changes: 23 additions & 12 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class Separator(object):

def __init__(self, model, device, batchsize, cropsize, postprocess=False):
def __init__(self, model, device=None, batchsize=1, cropsize=256, postprocess=False):
self.model = model
self.offset = model.offset
self.device = device
Expand All @@ -27,10 +27,15 @@ def _postprocess(self, X_spec, mask):
if self.postprocess:
mask_mag = np.abs(mask)
mask_mag = spec_utils.merge_artifacts(mask_mag)
mask = mask_mag + 1.j * np.exp(np.angle(mask))
mask = mask_mag * np.exp(1.j * np.angle(mask))

y_spec = X_spec * mask
v_spec = X_spec - y_spec
X_mag = np.abs(X_spec)
X_phase = np.angle(X_spec)

y_spec = mask * X_mag * np.exp(1.j * X_phase)
v_spec = (1 - mask) * X_mag * np.exp(1.j * X_phase)
# y_spec = X_spec * mask
# v_spec = X_spec - y_spec

return y_spec, v_spec

Expand All @@ -52,7 +57,7 @@ def _separate(self, X_spec_pad, roi_size):
X_batch = X_dataset[i: i + self.batchsize]
X_batch = torch.from_numpy(X_batch).to(self.device)

mask = self.model.predict_mask(X_batch)
mask = self.model.predict_mask(torch.abs(X_batch))

mask = mask.detach().cpu().numpy()
mask = np.concatenate(mask, axis=2)
Expand Down Expand Up @@ -109,26 +114,26 @@ def main():
p.add_argument('--cropsize', '-c', type=int, default=256)
p.add_argument('--output_image', '-I', action='store_true')
p.add_argument('--tta', '-t', action='store_true')
# p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--postprocess', '-p', action='store_true')
p.add_argument('--output_dir', '-o', type=str, default="")
args = p.parse_args()

print('loading model...', end=' ')
device = torch.device('cpu')
model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128)
model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
if args.gpu >= 0:
if torch.cuda.is_available():
device = torch.device('cuda:{}'.format(args.gpu))
model.to(device)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device('mps')
model.to(device)
model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128)
model.load_state_dict(torch.load(args.pretrained_model, map_location='cpu'))
model.to(device)
print('done')

print('loading wave source...', end=' ')
X, sr = librosa.load(
args.input, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast')
args.input, sr=args.sr, mono=False, dtype=np.float32, res_type='kaiser_fast'
)
basename = os.path.splitext(os.path.basename(args.input))[0]
print('done')

Expand All @@ -140,7 +145,13 @@ def main():
X_spec = spec_utils.wave_to_spectrogram(X, args.hop_length, args.n_fft)
print('done')

sp = Separator(model, device, args.batchsize, args.cropsize)
sp = Separator(
model=model,
device=device,
batchsize=args.batchsize,
cropsize=args.cropsize,
postprocess=args.postprocess
)

if args.tta:
y_spec, v_spec = sp.separate_tta(X_spec)
Expand Down
24 changes: 21 additions & 3 deletions lib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def read_npy_chunk(self, path, start_row):

return flat.reshape((-1,) + shape[1:])

def aggressively_remove_vocal(self, X, y):
X_mag = np.abs(X)
y_mag = np.abs(y)
v_mag = X_mag - y_mag
v_mag *= v_mag > y_mag

y_mag = np.clip(y_mag - v_mag * self.reduction_weight, 0, np.inf)

return y_mag * np.exp(1.j * np.angle(y))

def do_crop(self, X_path, y_path):
shape = self.read_npy_shape(X_path)
start_row = np.random.randint(0, shape[0] - self.cropsize)
Expand All @@ -57,7 +67,7 @@ def do_crop(self, X_path, y_path):

def do_aug(self, X, y):
if np.random.uniform() < self.reduction_rate:
y = spec_utils.aggressively_remove_vocal(X, y, self.reduction_weight)
y = self.aggressively_remove_vocal(X, y)

if np.random.uniform() < 0.5:
# swap channel
Expand Down Expand Up @@ -103,7 +113,11 @@ def __getitem__(self, idx):
if np.random.uniform() < self.mixup_rate:
X, y = self.do_mixup(X, y)

return X, y
X_mag = np.abs(X)
y_mag = np.abs(y)

return X_mag, y_mag
# return X, y


class VocalRemoverValidationSet(torch.utils.data.Dataset):
Expand All @@ -120,7 +134,11 @@ def __getitem__(self, idx):

X, y = data['X'], data['y']

return X, y
X_mag = np.abs(X)
y_mag = np.abs(y)

return X_mag, y_mag
# return X, y


def make_pair(mix_dir, inst_dir):
Expand Down
27 changes: 0 additions & 27 deletions lib/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,6 @@ def __call__(self, x):
return self.conv(x)


# class SeperableConv2DBNActiv(nn.Module):

# def __init__(self, nin, nout, ksize=3, stride=1, pad=1, dilation=1, activ=nn.ReLU):
# super(SeperableConv2DBNActiv, self).__init__()
# self.conv = nn.Sequential(
# nn.Conv2d(
# nin, nin,
# kernel_size=ksize,
# stride=stride,
# padding=pad,
# dilation=dilation,
# groups=nin,
# bias=False
# ),
# nn.Conv2d(
# nin, nout,
# kernel_size=1,
# bias=False
# ),
# nn.BatchNorm2d(nout),
# activ()
# )

# def __call__(self, x):
# return self.conv(x)


class Encoder(nn.Module):

def __init__(self, nin, nout, ksize=3, stride=1, pad=1, activ=nn.LeakyReLU):
Expand Down
51 changes: 28 additions & 23 deletions lib/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,46 @@ def __call__(self, x):

class CascadedNet(nn.Module):

def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128):
def __init__(self, n_fft, hop_length, nout=32, nout_lstm=128, is_complex=False):
super(CascadedNet, self).__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.is_complex = is_complex

self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64

nin = 4 if is_complex else 2

self.stg1_low_band_net = nn.Sequential(
BaseNet(4, nout // 2, self.nin_lstm // 2, nout_lstm),
BaseNet(nin, nout // 2, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(
4, nout // 4, self.nin_lstm // 2, nout_lstm // 2
nin, nout // 4, self.nin_lstm // 2, nout_lstm // 2
)

self.stg2_low_band_net = nn.Sequential(
BaseNet(nout // 4 + 4, nout, self.nin_lstm // 2, nout_lstm),
BaseNet(nout // 4 + nin, nout, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(
nout // 4 + 4, nout // 2, self.nin_lstm // 2, nout_lstm // 2
nout // 4 + nin, nout // 2, self.nin_lstm // 2, nout_lstm // 2
)

self.stg3_full_band_net = BaseNet(
3 * nout // 4 + 4, nout, self.nin_lstm, nout_lstm
3 * nout // 4 + nin, nout, self.nin_lstm, nout_lstm
)

self.out = nn.Conv2d(nout, 4, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 4, 1, bias=False)
self.out = nn.Conv2d(nout, nin, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, nin, 1, bias=False)

def forward(self, x):
x = torch.cat([x.real, x.imag], dim=1)
if self.is_complex:
x = torch.cat([x.real, x.imag], dim=1)

x = x[:, :, :self.max_bin]

bandw = x.size()[2] // 2
Expand All @@ -95,26 +101,25 @@ def forward(self, x):
f3_in = torch.cat([x, aux1, aux2], dim=1)
f3 = self.stg3_full_band_net(f3_in)

mask = torch.tanh(self.out(f3))
mask = torch.complex(mask[:, :2], mask[:, 2:])
if self.is_complex:
mask = self.out(f3)
mask = torch.complex(mask[:, :2], mask[:, 2:])
mask = self.bounded_mask(mask)
else:
mask = torch.sigmoid(self.out(f3))

mask = F.pad(
input=mask,
pad=(0, 0, 0, self.output_bin - mask.size()[2]),
mode='replicate'
)

if self.training:
aux = torch.cat([aux1, aux2], dim=1)
aux = torch.tanh(self.aux_out(aux))
aux = torch.complex(aux[:, :2], aux[:, 2:])
aux = F.pad(
input=aux,
pad=(0, 0, 0, self.output_bin - aux.size()[2]),
mode='replicate'
)
return mask, aux
else:
return mask
return mask

def bounded_mask(self, mask, eps=1e-8):
mask_mag = torch.abs(mask)
mask = torch.tanh(mask_mag) * mask / (mask_mag + eps)
return mask

def predict_mask(self, x):
mask = self.forward(x)
Expand Down
22 changes: 8 additions & 14 deletions lib/spec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,6 @@ def spectrogram_to_image(spec, mode='magnitude'):
return img


def aggressively_remove_vocal(X, y, weight):
X_mag = np.abs(X)
y_mag = np.abs(y)
# v_mag = np.abs(X_mag - y_mag)
v_mag = X_mag - y_mag
v_mag *= v_mag > y_mag

y_mag = np.clip(y_mag - v_mag * weight, 0, np.inf)

return y_mag * np.exp(1.j * np.angle(y))


def merge_artifacts(y_mask, thres=0.05, min_range=64, fade_size=32):
if min_range < fade_size * 2:
raise ValueError('min_range must be >= fade_size * 2')
Expand Down Expand Up @@ -182,13 +170,19 @@ def spectrogram_to_wave(spec, hop_length=1024):
import sys

X, _ = librosa.load(
sys.argv[1], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast')
sys.argv[1], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast'
)
y, _ = librosa.load(
sys.argv[2], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast')
sys.argv[2], sr=44100, mono=False, dtype=np.float32, res_type='kaiser_fast'
)

X, y = align_wave_head_and_tail(X, y, 44100)
X_spec = wave_to_spectrogram(X, 1024, 2048)
y_spec = wave_to_spectrogram(y, 1024, 2048)

# X_spec = np.load(sys.argv[1]).transpose(1, 2, 0)
# y_spec = np.load(sys.argv[2]).transpose(1, 2, 0)

v_spec = X_spec - y_spec

X_image = spectrogram_to_image(X_spec)
Expand Down
Loading

0 comments on commit 8a02fc5

Please sign in to comment.