Skip to content

Commit

Permalink
Merge pull request #108 from tsurumeso/develop
Browse files Browse the repository at this point in the history
Parameterize model size
  • Loading branch information
tsurumeso authored Jun 7, 2022
2 parents bcfe8b6 + a48ff29 commit 0efc190
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 16 deletions.
6 changes: 3 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def main():

print('loading model...', end=' ')
device = torch.device('cpu')
model = nets.CascadedNet(args.n_fft)
model = nets.CascadedNet(args.n_fft, 32, 128)
model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
if torch.cuda.is_available() and args.gpu >= 0:
device = torch.device('cuda:{}'.format(args.gpu))
Expand All @@ -150,10 +150,10 @@ def main():
y_spec, v_spec = sp.separate_tta(X_spec)
else:
y_spec, v_spec = sp.separate(X_spec)

print('validating output directory...', end=' ')
output_dir = args.output_dir
    if output_dir != "": # modifies output_dir if there's an arg specified
    if output_dir != "":  # modifies output_dir if there's an arg specified
output_dir = output_dir.rstrip('/') + '/'
os.makedirs(output_dir, exist_ok=True)
print('done')
Expand Down
8 changes: 6 additions & 2 deletions lib/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False
nn.AdaptiveAvgPool2d((1, None)),
Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
)
self.conv2 = Conv2DBNActiv(nin, nout, 1, 1, 0, activ=activ)
self.conv2 = Conv2DBNActiv(
nin, nout, 1, 1, 0, activ=activ
)
self.conv3 = Conv2DBNActiv(
nin, nout, 3, 1, dilations[0], dilations[0], activ=activ
)
Expand All @@ -109,7 +111,9 @@ def __init__(self, nin, nout, dilations=(4, 8, 12), activ=nn.ReLU, dropout=False
self.conv5 = Conv2DBNActiv(
nin, nout, 3, 1, dilations[2], dilations[2], activ=activ
)
self.bottleneck = Conv2DBNActiv(nout * 5, nout, 1, 1, 0, activ=activ)
self.bottleneck = Conv2DBNActiv(
nout * 5, nout, 1, 1, 0, activ=activ
)
self.dropout = nn.Dropout2d(0.1) if dropout else None

def forward(self, x):
Expand Down
26 changes: 16 additions & 10 deletions lib/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,35 @@ def __call__(self, x):

class CascadedNet(nn.Module):

def __init__(self, n_fft):
def __init__(self, n_fft, nout=32, nout_lstm=128):
super(CascadedNet, self).__init__()
self.max_bin = n_fft // 2
self.output_bin = n_fft // 2 + 1
self.nin_lstm = self.max_bin // 2
self.offset = 64

self.stg1_low_band_net = nn.Sequential(
BaseNet(2, 16, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(16, 8, 1, 1, 0)
BaseNet(2, nout // 2, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout // 2, nout // 4, 1, 1, 0)
)
self.stg1_high_band_net = BaseNet(
2, nout // 4, self.nin_lstm // 2, nout_lstm // 2
)
self.stg1_high_band_net = BaseNet(2, 8, self.nin_lstm // 2, 64)

self.stg2_low_band_net = nn.Sequential(
BaseNet(10, 32, self.nin_lstm // 2, 128),
layers.Conv2DBNActiv(32, 16, 1, 1, 0)
BaseNet(nout // 4 + 2, nout, self.nin_lstm // 2, nout_lstm),
layers.Conv2DBNActiv(nout, nout // 2, 1, 1, 0)
)
self.stg2_high_band_net = BaseNet(
nout // 4 + 2, nout // 2, self.nin_lstm // 2, nout_lstm // 2
)
self.stg2_high_band_net = BaseNet(10, 16, self.nin_lstm // 2, 64)

self.stg3_full_band_net = BaseNet(26, 32, self.nin_lstm, 128)
self.stg3_full_band_net = BaseNet(
3 * nout // 4 + 2, nout, self.nin_lstm, nout_lstm
)

self.out = nn.Conv2d(32, 2, 1, bias=False)
self.aux_out = nn.Conv2d(24, 2, 1, bias=False)
self.out = nn.Conv2d(nout, 2, 1, bias=False)
self.aux_out = nn.Conv2d(3 * nout // 4, 2, 1, bias=False)

def forward(self, x):
x = x[:, :, :self.max_bin]
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def main():
logger.info('{} {} {}'.format(i + 1, os.path.basename(X_fname), os.path.basename(y_fname)))

device = torch.device('cpu')
model = nets.CascadedNet(args.n_fft)
model = nets.CascadedNet(args.n_fft, 32, 128)
if args.pretrained_model is not None:
model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
if torch.cuda.is_available() and args.gpu >= 0:
Expand Down

0 comments on commit 0efc190

Please sign in to comment.