From 2c1ef2f64d1cc0854366c8beb02d413010620731 Mon Sep 17 00:00:00 2001 From: Ben Hayes Date: Thu, 9 Feb 2023 18:32:19 +0000 Subject: [PATCH 1/5] Begin dpt impl --- aimless/models/dpt.py | 61 ++++++++++++++++++++++++++++++++++++++++ setup.cfg | 5 ++++ setup.py | 1 + tests/__init__.py | 0 tests/models/__init__.py | 0 tests/models/test_dpt.py | 45 +++++++++++++++++++++++++++++ 6 files changed, 112 insertions(+) create mode 100644 aimless/models/dpt.py create mode 100644 setup.cfg create mode 100644 tests/__init__.py create mode 100644 tests/models/__init__.py create mode 100644 tests/models/test_dpt.py diff --git a/aimless/models/dpt.py b/aimless/models/dpt.py new file mode 100644 index 0000000..0f2e431 --- /dev/null +++ b/aimless/models/dpt.py @@ -0,0 +1,61 @@ +import torch +from torch import nn + +__all__ = ["DPTLayer"] + + +def _get_activation(activation: str): + try: + return getattr(torch.nn, activation)() + except AttributeError: + raise ValueError(f"Activation {activation} not supported.") + + +class DPTLayer(nn.Module): + """Implements a layer of the Dual-Path Transformer, as described in [1] + + Args: + embedding_size (int): The size of the input embedding. + num_heads (int): The number of attention heads. + hidden_size (int): The size of the hidden layer in the LSTM. + dropout (float): The dropout rate. + activation (str): The activation function to use. One of ["relu", "gelu"]. + **lstm_kwargs: Additional keyword arguments to pass to the LSTM. + + References: + [1] https://arxiv.org/abs/2007.13975 + """ + + def __init__( + self, + embedding_size: int, + num_heads: int, + hidden_size: int, + dropout: float, + activation: str, + **lstm_kwargs, + ): + super().__init__() + + self.attention = nn.MultiheadAttention( + embedding_size, num_heads, dropout=dropout, batch_first=True + ) + self.norm1 = nn.LayerNorm(embedding_size) + + self.lstm = nn.LSTM( + embedding_size, hidden_size, batch_first=True, **lstm_kwargs + ) + self.dense = nn.Sequential( + nn.Linear(hidden_size, embedding_size), _get_activation(activation) + ) + self.norm2 = nn.LayerNorm(embedding_size) + + def forward(self, x: torch.Tensor): + x = x + self.attention(x, x, x)[0] + x = self.norm1(x) + + h, _ = self.lstm(x) + x = x + self.dense(h) + x = self.norm2(x) + + return x diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..ed2418a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +# make flake8 play nicely with black +[flake8] +max-line-length = 88 +extend-ignore = E203 + diff --git a/setup.py b/setup.py index d01d7a5..fb38b61 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ author_email="chin-yun.yu@qmul.ac.uk", packages=setuptools.find_packages(exclude=["tests", "tests.*", "data", "data.*"]), install_requires=["torch", "pytorch-lightning", "torch_fftconv"], + extras_require={"dev": ["pytest", "pytest-mock"]}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/test_dpt.py b/tests/models/test_dpt.py new file mode 100644 index 0000000..89c97f0 --- /dev/null +++ b/tests/models/test_dpt.py @@ -0,0 +1,45 @@ +import pytest +import torch + +from aimless.models import dpt + + +@pytest.fixture(params=[10]) +def embedding_size(request): + return request.param + + +@pytest.fixture +def dpt_layer(embedding_size): + return 
dpt.DPTLayer(embedding_size, 2, 20, 0.1, "ReLU")
+
+
+def test_dpt_layer_forward(dpt_layer, embedding_size):
+    batch_size = 10
+    seq_len = 20
+    inputs = torch.testing.make_tensor(
+        batch_size, seq_len, embedding_size, dtype=torch.float32, device="cpu"
+    )
+
+    outputs = dpt_layer(inputs)
+
+    assert outputs.shape == (batch_size, seq_len, embedding_size)
+
+
+def test_dpt_layer_transforms_input(dpt_layer, embedding_size):
+    batch_size = 11
+    seq_len = 21
+    inputs = torch.testing.make_tensor(
+        batch_size, seq_len, embedding_size, dtype=torch.float32, device="cpu"
+    )
+
+    outputs = dpt_layer(inputs)
+
+    assert not torch.allclose(inputs, outputs)
+
+
+@pytest.mark.parametrize("activation", ["ReLU", "GELU", "ELU", "LeakyReLU", "Tanh"])
+def test_dpt_layer_uses_selected_activation(mocker, activation):
+    spy = mocker.spy(dpt.torch.nn, activation)
+    dpt.DPTLayer(10, 2, 20, 0.1, activation)
+    spy.assert_called_once()

From 4c4033f935c75d9a120fa750caebbb6caddb1c60 Mon Sep 17 00:00:00 2001
From: Ben Hayes
Date: Fri, 10 Feb 2023 17:43:17 +0000
Subject: [PATCH 2/5] Implement DPT

---
 aimless/models/dpt.py    | 258 ++++++++++++++++++++++++++++++++++++++-
 setup.py                 |   2 +-
 tests/models/test_dpt.py |  72 +++++++++++
 3 files changed, 325 insertions(+), 7 deletions(-)

diff --git a/aimless/models/dpt.py b/aimless/models/dpt.py
index 0f2e431..9b417c9 100644
--- a/aimless/models/dpt.py
+++ b/aimless/models/dpt.py
@@ -1,8 +1,9 @@
+from typing import Callable
+
+from einops import rearrange
 import torch
 from torch import nn
 
-__all__ = ["DPTLayer"]
-
 
 def _get_activation(activation: str):
     try:
@@ -11,15 +12,53 @@
         raise ValueError(f"Activation {activation} not supported.")
 
 
+class DPTFilterbank(nn.Module):
+    """A Conv-TasNet style filterbank for DPT.
+
+    Args:
+        input_channels (int): The number of input channels.
+        num_filters (int): The number of filters to learn.
+        kernel_size (int): The size of the filters.
+        nonlinearity (str, optional): The name of an activation class in
+            `torch.nn`, e.g. "ReLU" or "GELU". Pass None to apply no
+            nonlinearity.
+        transpose (bool): If True, use a transposed convolution so that the
+            filterbank acts as a decoder.
+    """
+
+    def __init__(
+        self,
+        input_channels: int = 1,
+        num_filters: int = 64,
+        kernel_size: int = 16,
+        nonlinearity: str = "ReLU",
+        transpose: bool = False,
+    ):
+        super().__init__()
+
+        _conv_module = nn.ConvTranspose1d if transpose else nn.Conv1d
+
+        self.filterbank = _conv_module(
+            input_channels,
+            num_filters,
+            kernel_size,
+            stride=kernel_size // 2,
+            padding=0,
+        )
+        self.nonlinearity = (
+            _get_activation(nonlinearity) if nonlinearity is not None else lambda x: x
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.nonlinearity(self.filterbank(x))
+
+
 class DPTLayer(nn.Module):
-    """Implements a layer of the Dual-Path Transformer, as described in [1]
+    """One layer of the Dual-Path Transformer, as described in [1]
 
     Args:
         embedding_size (int): The size of the input embedding.
         num_heads (int): The number of attention heads.
         hidden_size (int): The size of the hidden layer in the LSTM.
         dropout (float): The dropout rate.
-        activation (str): The activation function to use. One of ["relu", "gelu"].
+        activation (str): The name of an activation class in `torch.nn`,
+            e.g. "ReLU" or "GELU".
+        bidirectional (bool): Whether the LSTM is bidirectional. Defaults to
+            True.
         **lstm_kwargs: Additional keyword arguments to pass to the LSTM.
 
     References:
@@ -33,20 +72,34 @@ def __init__(
         hidden_size: int,
         dropout: float,
         activation: str,
+        bidirectional: bool = True,
         **lstm_kwargs,
     ):
         super().__init__()
 
+        if embedding_size % num_heads != 0:
+            raise ValueError(
+                f"Embedding size {embedding_size} must be divisible by number of "
+                f"heads {num_heads}."
+ ) + self.attention = nn.MultiheadAttention( embedding_size, num_heads, dropout=dropout, batch_first=True ) self.norm1 = nn.LayerNorm(embedding_size) self.lstm = nn.LSTM( - embedding_size, hidden_size, batch_first=True, **lstm_kwargs + embedding_size, + hidden_size, + batch_first=True, + bidirectional=bidirectional, + **lstm_kwargs, ) self.dense = nn.Sequential( - nn.Linear(hidden_size, embedding_size), _get_activation(activation) + nn.Linear( + hidden_size * 2 if bidirectional else hidden_size, embedding_size + ), + _get_activation(activation), ) self.norm2 = nn.LayerNorm(embedding_size) @@ -59,3 +112,196 @@ def forward(self, x: torch.Tensor): x = self.norm2(x) return x + + +class DualPathLayer(nn.Module): + """Apply intra- and inter-chunk transformer layers. + + Args: + embedding_size (int): The size of the input embedding. + num_heads (int): The number of attention heads. + hidden_size (int): The size of the hidden layer in the LSTM. + dropout (float): The dropout rate. + activation (str): The activation function to use in `nn.modules.activations`. + **lstm_kwargs: Additional keyword arguments to pass to the LSTM. + """ + + def __init__( + self, + embedding_size: int, + num_heads: int, + hidden_size: int, + dropout: float, + activation: str, + **lstm_kwargs, + ): + super().__init__() + self.intra_chunk = DPTLayer( + embedding_size, num_heads, hidden_size, dropout, activation, **lstm_kwargs + ) + self.inter_chunk = DPTLayer( + embedding_size, num_heads, hidden_size, dropout, activation, **lstm_kwargs + ) + + def _apply_intra_chunk(self, x: torch.Tensor, fn: Callable): + batch_size, *_ = x.shape + x = rearrange(x, "b c m n -> (b n) m c") + x = fn(x) + x = rearrange(x, "(b n) m c -> b c m n", b=batch_size) + return x + + def _apply_inter_chunk(self, x: torch.Tensor, fn: Callable): + batch_size, *_ = x.shape + x = rearrange(x, "b c m n -> (b m) n c") + x = fn(x) + x = rearrange(x, "(b m) n c -> b c m n", b=batch_size) + return x + + def forward(self, x: torch.Tensor): + x = self._apply_intra_chunk(x, self.intra_chunk) + x = self._apply_inter_chunk(x, self.inter_chunk) + return x + + +class DPT(nn.Module): + """The Dual-Path Transformer, as described in [1]. 
+
+    Args:
+        channels (int): The number of audio channels in the input mixture.
+        num_sources (int): The number of sources to separate.
+        num_filters (int): The number of filters in the encoder filterbank.
+        filter_size (int): The kernel size of the filterbank convolutions.
+        filterbank_nonlinearity (str): The activation applied to the encoder
+            output.
+        segment_size (int): The length of each chunk in dual-path processing.
+        segment_stride (int): The hop between successive chunks.
+        num_dual_path_layers (int): The number of stacked dual-path layers.
+        num_attention_heads (int): The number of attention heads per layer.
+        lstm_hidden_size (int): The hidden size of each layer's LSTM.
+        transformer_dropout (float): The dropout rate of the attention layers.
+        transformer_nonlinearity (str): The activation used in the dense
+            layers of the transformer.
+        post_transformer_prelu (bool): Whether to apply a PReLU before the
+            mask projection.
+        mask_nonlinearity (str): The activation applied to the estimated
+            masks.
+
+    References:
+        [1] https://arxiv.org/abs/2007.13975
+    """
+
+    def __init__(
+        self,
+        channels: int = 2,
+        num_sources: int = 4,
+        num_filters: int = 64,
+        filter_size: int = 16,
+        filterbank_nonlinearity: str = "ReLU",
+        segment_size: int = 100,
+        segment_stride: int = 50,
+        num_dual_path_layers: int = 6,
+        num_attention_heads: int = 4,
+        lstm_hidden_size: int = 256,
+        transformer_dropout: float = 0.1,
+        transformer_nonlinearity: str = "GELU",
+        post_transformer_prelu: bool = True,
+        mask_nonlinearity: str = "ReLU",
+    ):
+        super().__init__()
+        self.num_sources = num_sources
+
+        self.encoder = DPTFilterbank(
+            input_channels=channels,
+            num_filters=num_filters,
+            kernel_size=filter_size,
+            nonlinearity=filterbank_nonlinearity,
+        )
+        self.decoder = DPTFilterbank(
+            input_channels=num_filters,
+            num_filters=channels,
+            kernel_size=filter_size,
+            nonlinearity=None,
+            transpose=True,
+        )
+        self.pre_norm = nn.LayerNorm(num_filters)
+        self.segment_size = segment_size
+        self.segment_stride = segment_stride
+
+        transformer_net = []
+        for _ in range(num_dual_path_layers):
+            transformer_net.append(
+                DualPathLayer(
+                    num_filters,
+                    num_attention_heads,
+                    lstm_hidden_size,
+                    dropout=transformer_dropout,
+                    activation=transformer_nonlinearity,
+                    bidirectional=True,
+                )
+            )
+        self.transformer_net = nn.Sequential(*transformer_net)
+
+        post_transformer = []
+        if post_transformer_prelu:
+            post_transformer.append(nn.PReLU())
+        post_transformer.append(nn.Conv2d(num_filters, num_sources * num_filters, 1))
+        self.post_transformer = nn.Sequential(*post_transformer)
+
+        self.gate_paths = nn.ModuleList(
+            [
+                nn.Sequential(nn.Conv1d(num_filters, num_filters, 1), nn.Tanh()),
+                nn.Sequential(nn.Conv1d(num_filters, num_filters, 1), nn.Sigmoid()),
+            ]
+        )
+
+        self.mask_activation = _get_activation(mask_nonlinearity)
+
+    def _global_norm(self, x: torch.Tensor):
+        return x / x.norm(dim=1, keepdim=True)
+
+    def _segment(self, x: torch.Tensor):
+        x = rearrange(x, "b c t -> b c t ()")
+        x_segmented = nn.functional.unfold(
+            x,
+            kernel_size=(self.segment_size, 1),
+            stride=(self.segment_stride, 1),
+            padding=(self.segment_size, 0),
+        )
+        x = rearrange(x_segmented, "b (c m) n -> b c m n", m=self.segment_size)
+        return x
+
+    def _unsegment(self, x: torch.Tensor, original_len: int):
+        x = rearrange(x, "b c m n -> b (c m) n")
+        x = nn.functional.fold(
+            x,
+            output_size=(original_len, 1),
+            kernel_size=(self.segment_size, 1),
+            stride=(self.segment_stride, 1),
+            padding=(self.segment_size, 0),
+        )
+        x = rearrange(x, "b c t () -> b c t")
+        return x
+
+    def forward(self, x: torch.Tensor):
+        # apply input filterbank
+        x = self.encoder(x)
+
+        # preserve shape for unsegmenting later
+        *_, original_len = x.shape
+
+        # pre-normalisation
+        m = self.pre_norm(x.transpose(1, 2)).transpose(1, 2)
+
+        # perform segmentation on the normalised features
+        m = self._segment(m)
+
+        # apply transformer
+        m = self.transformer_net(m)
+
+        # project to high dimension
+        m = self.post_transformer(m)
+        m = rearrange(m, "b (s c) m n -> (b s) c m n", s=self.num_sources)
+
+        # unsegment
+        m = self._unsegment(m, original_len)
+
+        # apply gating
+        m = [g(m) for g in self.gate_paths]
+        m = torch.mul(*m)
+
+        # reshape to recover masks
+        m = rearrange(m, "(b s) c t -> b s c t", s=self.num_sources)
+        m = self.mask_activation(m)
+
+        # apply masks
+        y_ = m * rearrange(x, "b c t -> b () c t")
+
+        # move sources to batch dimension and apply transposed filterbank
+        y_ = rearrange(y_, "b s c t -> (b s) c t")
+        y_ = self.decoder(y_)
+        y_ = rearrange(y_, "(b s) c t -> b s c t", s=self.num_sources)
+
+        return y_
diff --git a/setup.py b/setup.py
index fb38b61..e508766 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
     author="Artificial Intelligence and Music League for Effective Source Separation",
     author_email="chin-yun.yu@qmul.ac.uk",
     packages=setuptools.find_packages(exclude=["tests", "tests.*", "data", "data.*"]),
-    install_requires=["torch", "pytorch-lightning", "torch_fftconv"],
+    install_requires=["torch", "pytorch-lightning", "torch_fftconv", "einops"],
     extras_require={"dev": ["pytest", "pytest-mock"]},
     classifiers=[
         "Programming Language :: Python :: 3",
diff --git a/tests/models/test_dpt.py b/tests/models/test_dpt.py
index 89c97f0..5353d3e 100644
--- a/tests/models/test_dpt.py
+++ b/tests/models/test_dpt.py
@@ -43,3 +43,75 @@ def test_dpt_layer_uses_selected_activation(mocker, activation):
     spy = mocker.spy(dpt.torch.nn, activation)
     dpt.DPTLayer(10, 2, 20, 0.1, activation)
     spy.assert_called_once()
+
+
+def test_dpt_filterbank_forward():
+    in_channels = 1
+    num_filters = 2
+    kernel_size = 3
+
+    model = dpt.DPTFilterbank(in_channels, num_filters, kernel_size)
+
+    batch_size = 10
+    seq_len = 20
+    inputs = torch.testing.make_tensor(
+        batch_size, in_channels, seq_len, dtype=torch.float32, device="cpu"
+    )
+
+    outputs = model(inputs)
+
+    assert outputs.shape == (batch_size, num_filters, seq_len - 2 * (kernel_size // 2))
+
+
+def test_dpt_filterbank_forward_transposed():
+    in_channels = 1
+    num_filters = 2
+    kernel_size = 3
+
+    model = dpt.DPTFilterbank(num_filters, in_channels, kernel_size, transpose=True)
+
+    batch_size = 10
+    seq_len = 20
+    inputs = torch.testing.make_tensor(
+        batch_size, num_filters, seq_len, dtype=torch.float32, device="cpu"
+    )
+
+    outputs = model(inputs)
+
+    assert outputs.shape == (batch_size, in_channels, seq_len + 2 * (kernel_size // 2))
+
+
+def test_dpt_forward():
+    batch_size = 3
+    input_channels = 1
+    seq_len = 32
+
+    num_sources = 3
+    num_filters = 4
+    filter_size = 3
+
+    segment_size = 4
+    segment_stride = 2
+
+    num_dual_path_layers = 2
+    num_attention_heads = 2
+
+    lstm_hidden_size = 16
+
+    model = dpt.DPT(
+        channels=input_channels,
+        num_sources=num_sources,
+        num_filters=num_filters,
+        filter_size=filter_size,
+        segment_size=segment_size,
+        segment_stride=segment_stride,
+        num_dual_path_layers=num_dual_path_layers,
+        num_attention_heads=num_attention_heads,
+        lstm_hidden_size=lstm_hidden_size,
+    )
+
+    x = torch.testing.make_tensor(
+        batch_size, input_channels, seq_len, dtype=torch.float32, device="cpu"
+    )
+
+    y = model(x)
+
+    # with these settings the decoder exactly undoes the encoder's shrinkage,
+    # so the separated waveforms have the same length as the input
+    assert y.shape == (batch_size, num_sources, input_channels, seq_len)

From e62a7eb3253fec08c1b6e2849f6cba09ba81359b Mon Sep 17 00:00:00 2001
From: Ben Hayes
Date: Fri, 10 Feb 2023 17:44:58 +0000
Subject: [PATCH 3/5] Remove unused method

---
 aimless/models/dpt.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/aimless/models/dpt.py b/aimless/models/dpt.py
index 9b417c9..c808df0 100644
--- a/aimless/models/dpt.py
+++ b/aimless/models/dpt.py
@@ -239,9 +239,6 @@
         self.mask_activation = _get_activation(mask_nonlinearity)
 
-    def _global_norm(self, x: torch.Tensor):
-        return x / x.norm(dim=1, keepdim=True)
-
     def _segment(self, x: torch.Tensor):

From 9bae0d86899641fd19a52287f1578aa47b70aa45 Mon Sep 17 00:00:00 2001
From: Ben Hayes
Date: Fri, 10 Feb 2023 18:13:54 +0000
Subject: [PATCH 4/5] Add config for DPT

---
 cfg/dpt.yaml | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 116 insertions(+)
 create mode 100644 cfg/dpt.yaml

diff --git 
a/cfg/dpt.yaml b/cfg/dpt.yaml new file mode 100644 index 0000000..940cdaf --- /dev/null +++ b/cfg/dpt.yaml @@ -0,0 +1,116 @@ +# pytorch_lightning==1.8.5.post0 +seed_everything: true +trainer: + logger: true + enable_checkpointing: true + callbacks: null + default_root_dir: null + gradient_clip_val: null + gradient_clip_algorithm: null + num_nodes: 1 + num_processes: null + devices: null + gpus: null + auto_select_gpus: false + tpu_cores: null + ipus: null + enable_progress_bar: true + overfit_batches: 0.0 + track_grad_norm: -1 + check_val_every_n_epoch: 1 + fast_dev_run: false + accumulate_grad_batches: null + max_epochs: null + min_epochs: null + max_steps: -1 + min_steps: null + max_time: null + limit_train_batches: null + limit_val_batches: null + limit_test_batches: null + limit_predict_batches: null + val_check_interval: null + log_every_n_steps: 1 + accelerator: gpu + strategy: ddp + sync_batchnorm: true + precision: 32 + enable_model_summary: true + num_sanity_val_steps: 2 + resume_from_checkpoint: null + profiler: null + benchmark: null + deterministic: null + reload_dataloaders_every_n_epochs: 0 + auto_lr_find: false + replace_sampler_ddp: true + detect_anomaly: false + auto_scale_batch_size: false + plugins: null + amp_backend: native + amp_level: null + move_metrics_to_cpu: false + multiple_trainloader_mode: max_size_cycle + inference_mode: true +ckpt_path: null +model: + class_path: aimless.lightning.waveform.WaveformSeparator + init_args: + model: + class_path: aimless.models.dpt.DPT + init_args: + channels: 2 + num_sources: 4 + num_filters: 64 + filter_size: 16 + filterbank_nonlinearity: ReLU + segment_size: 100 + segment_stride: 50 + num_dual_path_layers: 6 + num_attention_heads: 4 + lstm_hidden_size: 128 + transformer_dropout: 0.1 + transformer_nonlinearity: GELU + post_transformer_prelu: true + mask_nonlinearity: ReLU + criterion: + class_path: aimless.loss.time.L1Loss + transforms: + - class_path: aimless.augment.SpeedPerturb + init_args: + orig_freq: 44100 + speeds: + - 90 + - 100 + - 110 + p: 0.2 + - class_path: aimless.augment.RandomPitch + init_args: + semitones: + - -1 + - 1 + - 0 + - 1 + - 2 + p: 0.2 + targets: {vocals, drums, bass, other} +data: + class_path: data.lightning.MUSDB + init_args: + root: /import/c4dm-datasets-ext/musdb18hq/ + seq_duration: 10.0 + samples_per_track: 500 + transforms: + - class_path: data.augment.RandomGain + - class_path: data.augment.RandomFlipPhase + - class_path: data.augment.RandomSwapLR + - class_path: data.augment.LimitAug + init_args: + sample_rate: 44100 + random: true + random_track_mix: true + batch_size: 4 +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 0.0003 \ No newline at end of file From f712a35ddb486ffa1a2368a8a981a86e208332bd Mon Sep 17 00:00:00 2001 From: Matthew Rice Date: Wed, 1 Mar 2023 17:44:44 +0000 Subject: [PATCH 5/5] Updated DPTnet to latest cfg --- cfg/dpt.yaml | 53 ++++++++++++++++++++++++++++++++++++++++++------- environment.yml | 1 + setup.py | 2 +- 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/cfg/dpt.yaml b/cfg/dpt.yaml index 940cdaf..a8dc72f 100644 --- a/cfg/dpt.yaml +++ b/cfg/dpt.yaml @@ -56,7 +56,7 @@ ckpt_path: null model: class_path: aimless.lightning.waveform.WaveformSeparator init_args: - model: + model: class_path: aimless.models.dpt.DPT init_args: channels: 2 @@ -73,13 +73,13 @@ model: transformer_nonlinearity: GELU post_transformer_prelu: true mask_nonlinearity: ReLU - criterion: + criterion: class_path: aimless.loss.time.L1Loss transforms: - class_path: 
aimless.augment.SpeedPerturb init_args: orig_freq: 44100 - speeds: + speeds: - 90 - 100 - 110 @@ -101,12 +101,51 @@ data: seq_duration: 10.0 samples_per_track: 500 transforms: - - class_path: data.augment.RandomGain - - class_path: data.augment.RandomFlipPhase - - class_path: data.augment.RandomSwapLR - - class_path: data.augment.LimitAug + - class_path: data.augment.RandomParametricEQ init_args: sample_rate: 44100 + p: 0.7 + - class_path: data.augment.RandomPedalboardDistortion + init_args: + sample_rate: 44100 + p: 0.01 + - class_path: data.augment.RandomPedalboardDelay + init_args: + sample_rate: 44100 + p: 0.02 + - class_path: data.augment.RandomPedalboardChorus + init_args: + sample_rate: 44100 + p: 0.01 + - class_path: data.augment.RandomPedalboardPhaser + init_args: + sample_rate: 44100 + p: 0.01 + - class_path: data.augment.RandomPedalboardCompressor + init_args: + sample_rate: 44100 + p: 0.5 + - class_path: data.augment.RandomPedalboardReverb + init_args: + sample_rate: 44100 + p: 0.2 + - class_path: data.augment.RandomStereoWidener + init_args: + sample_rate: 44100 + p: 0.3 + - class_path: data.augment.RandomPedalboardLimiter + init_args: + sample_rate: 44100 + p: 0.1 + - class_path: data.augment.RandomVolumeAutomation + init_args: + sample_rate: 44100 + p: 0.1 + - class_path: data.augment.LoudnessNormalize + init_args: + sample_rate: 44100 + target_lufs_db: -32.0 + p: 1.0 random: true random_track_mix: true batch_size: 4 diff --git a/environment.yml b/environment.yml index 70d7a91..548072b 100644 --- a/environment.yml +++ b/environment.yml @@ -11,6 +11,7 @@ dependencies: - torchaudio - torchvision - cudatoolkit=11.7 + - einops - pip: - pytorch-lightning[extra] - torch-optimizer diff --git a/setup.py b/setup.py index e508766..fb38b61 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ author="Artificial Intelligence and Music League for Effective Source Separation", author_email="chin-yun.yu@qmul.ac.uk", packages=setuptools.find_packages(exclude=["tests", "tests.*", "data", "data.*"]), - install_requires=["torch", "pytorch-lightning", "torch_fftconv", "einops"], + install_requires=["torch", "pytorch-lightning", "torch_fftconv"], extras_require={"dev": ["pytest", "pytest-mock"]}, classifiers=[ "Programming Language :: Python :: 3",
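
Usage sketch (not part of the patch series above): a minimal smoke test of the
DPT as implemented in aimless/models/dpt.py. It assumes the package is
installed so that `aimless.models.dpt` is importable, and mirrors the
hyperparameters in cfg/dpt.yaml with a shallower stack so it runs quickly on
CPU:

    import torch

    from aimless.models.dpt import DPT

    model = DPT(
        channels=2,              # stereo mixture, as in cfg/dpt.yaml
        num_sources=4,           # vocals, drums, bass, other
        num_filters=64,
        filter_size=16,
        segment_size=100,
        segment_stride=50,
        num_dual_path_layers=2,  # 6 in the config; reduced here for speed
        num_attention_heads=4,
        lstm_hidden_size=128,
    )
    model.eval()

    mixture = torch.randn(1, 2, 44100)  # one second of stereo audio at 44.1 kHz
    with torch.no_grad():
        estimates = model(mixture)

    # one masked-and-decoded waveform per source: (batch, sources, channels, time')
    # time' is slightly shorter than the input because the encoder/decoder
    # convolutions are unpadded
    print(estimates.shape)

For training, the model would instead be constructed from cfg/dpt.yaml through
the repository's PyTorch Lightning CLI; the exact launch command is not shown
in these patches.

The chunking inside DPT.forward is the subtlest part, so here is the
fold/unfold round trip from _segment/_unsegment in isolation, with
illustrative sizes:

    import torch
    from torch import nn

    segment_size, segment_stride = 100, 50
    x = torch.randn(1, 64, 5511)  # (batch, filters, time), as the encoder emits

    # chunk: (b, c, t, 1) -> (b, c * segment_size, num_chunks)
    unfolded = nn.functional.unfold(
        x.unsqueeze(-1),
        kernel_size=(segment_size, 1),
        stride=(segment_stride, 1),
        padding=(segment_size, 0),
    )
    chunks = unfolded.view(1, 64, segment_size, -1)  # (b, c, m, n)

    # overlap-add back to the original length
    folded = nn.functional.fold(
        chunks.view(1, 64 * segment_size, -1),
        output_size=(x.shape[-1], 1),
        kernel_size=(segment_size, 1),
        stride=(segment_stride, 1),
        padding=(segment_size, 0),
    )
    print(folded.shape)  # torch.Size([1, 64, 5511, 1])

Note that fold sums overlapping chunks, so the round trip scales interior
values by segment_size / segment_stride (2 here); the patches do not divide
this factor out, presumably leaving it to be absorbed by the surrounding
learned layers.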