Skip to content

Commit

Permalink
fixed bugs in guided_choices and utils
Browse files Browse the repository at this point in the history
  • Loading branch information
razor1179 committed May 13, 2019
1 parent 3bff7fd commit 9c40dbe
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 97 deletions.
160 changes: 76 additions & 84 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions cfg/TIMIT_CGS/TIMIT_LSTM_fmllr_ghcgs.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ cfg_proto_chunk = proto/global_chunk.proto
[exp]
cmd =
run_nn_script = run_nn
out_folder = exp/TIMIT_LSTM_fmllr_test2_ghcgs_onlyFinal_25d32b_75d4b
out_folder = exp/TIMIT_LSTM_fmllr_test2_gcgs_onlyFirst_81p25d32b
seed = 22341
use_cuda = True
multi_gpu = False
save_gpumem = False
n_epochs_tr = 8
apply_guided_ep = 8
apply_guided_ep = 2

[dataset1]
data_name = TIMIT_tr
Expand Down Expand Up @@ -160,10 +160,10 @@ out_folder =
lstm_hcgs = False
guided_hcgs = True
apply_guided_hcgs = False
hcgsx_block = 32,4
hcgsx_sparse = 25,75
hcgsh_block = 32,4
hcgsh_sparse = 25,75
hcgsx_block = 32
hcgsx_sparse = 81.25
hcgsh_block = 32
hcgsh_sparse = 81.25
lstm_quant = False
param_quant = 6,6
lstm_quant_inp = False
Expand Down Expand Up @@ -200,8 +200,8 @@ out_folder =
mlp_hcgs = False
guided_hcgs = True
apply_guided_hcgs = False
hcgs_block = 32,4
hcgs_sparse = 25,75
hcgs_block = 32
hcgs_sparse = 81.25
mlp_quant = False
param_quant = 5
mlp_quant_inp = False
Expand Down
6 changes: 3 additions & 3 deletions cfg/TIMIT_CGS/TIMIT_LSTM_fmllr_ghcgs_L1.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ cfg_proto_chunk = proto/global_chunk.proto
[exp]
cmd =
run_nn_script = run_nn
out_folder = exp/TIMIT_LSTM_fmllr_test_ghcgs_l1_onlyFinal_25d32b_75d4b
seed = 2234
out_folder = exp/TIMIT_LSTM_fmllr_test2_ghcgs_l1_onlyFirst_25d32b_75d4b
seed = 22341
use_cuda = True
multi_gpu = False
save_gpumem = False
n_epochs_tr = 8
apply_guided_ep = 8
apply_guided_ep = 2

[dataset1]
data_name = TIMIT_tr
Expand Down
2 changes: 1 addition & 1 deletion guided_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def guided_array_rows (w_mat, n_blk, n_blk_sels, blk_size):
w_mat2 = w_mat[: ,x:x + c1]
temp = torch.randn(1, 1, r, c1)
temp[0, 0, :, :] = w_mat2
if r == blk_size:
if r == blk_size and c1 >= blk_size:
avg = torch.nn.AvgPool2d(blk_size, c1)
else:
avg = torch.nn.AvgPool2d((r, c1), c1)
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,7 +1898,7 @@ def forward_model(fea_dict, lab_dict, arch_dict, model, nns, costs, inp, inp_out
for line in model:
[out_name, operation, inp1, inp2] = list(re.findall(pattern, line)[0])

if out_name == 'loss_gl':
if out_name[0:7] == 'loss_gl':
pattern2 = '(.*)=(.*)\((.*),(.*),(.*)\)'
[out_name, operation, inp1, inp2, inp3] = list(re.findall(pattern2, line)[0])

Expand Down

0 comments on commit 9c40dbe

Please sign in to comment.