Skip to content

Commit

Permalink
Merge pull request #4 from olivesgatech/develop
Browse files Browse the repository at this point in the history
fixed visualization code for section_train.py
  • Loading branch information
yalaudah authored Feb 13, 2019
2 parents 9019537 + 9c71bdb commit 9a4f234
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions section_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from core.loader.data_loader import *
from core.metrics import runningScore
from core.models import get_model

from core.utils import np_to_tb

def split_train_val(args, per_val=0.1):
# create inline and crossline sections for training and validation:
Expand Down Expand Up @@ -201,7 +201,7 @@ def __iter__(self):
correct_label_decoded = train_set.decode_segmap(
np.squeeze(labels_original))
writer.add_image('train/original_label',
correct_label_decoded, epoch + 1)
np_to_tb(correct_label_decoded), epoch + 1)
out = F.softmax(outputs, dim=1)

# this returns the max. channel number:
Expand All @@ -212,7 +212,7 @@ def __iter__(self):
confidence, normalize=True, scale_each=True)

decoded = train_set.decode_segmap(np.squeeze(prediction))
writer.add_image('train/predicted', decoded, epoch + 1)
writer.add_image('train/predicted', np_to_tb(decoded), epoch + 1)
writer.add_image('train/confidence', tb_confidence, epoch + 1)

unary = outputs.cpu().detach()
Expand Down Expand Up @@ -276,7 +276,7 @@ def __iter__(self):
correct_label_decoded = train_set.decode_segmap(
np.squeeze(labels_original))
writer.add_image('val/original_label',
correct_label_decoded, epoch + 1)
np_to_tb(correct_label_decoded), epoch + 1)

out = F.softmax(outputs_val, dim=1)

Expand All @@ -287,11 +287,9 @@ def __iter__(self):
tb_confidence = vutils.make_grid(
confidence, normalize=True, scale_each=True)

decoded = train_set.decode_segmap(
np.squeeze(prediction))
writer.add_image('val/predicted', decoded, epoch + 1)
writer.add_image('val/confidence',
tb_confidence, epoch + 1)
decoded = train_set.decode_segmap(np.squeeze(prediction))
writer.add_image('val/predicted', np_to_tb(decoded), epoch + 1)
writer.add_image('val/confidence',tb_confidence, epoch + 1)

unary = outputs.cpu().detach()
unary_max, unary_min = torch.max(
Expand Down Expand Up @@ -350,7 +348,7 @@ def __iter__(self):
help='Path to previous saved model to restart from')
parser.add_argument('--clip', nargs='?', type=float, default=0.1,
help='Max norm of the gradients if clipping. Set to zero to disable. ')
parser.add_argument('--per_val', nargs='?', type=float, default=0,
parser.add_argument('--per_val', nargs='?', type=float, default=0.2,
help='percentage of the training data for validation')
parser.add_argument('--pretrained', nargs='?', type=bool, default=False,
help='Pretrained models not supported. Keep as False for now.')
Expand Down

0 comments on commit 9a4f234

Please sign in to comment.