Skip to content

Commit

Permalink
Merge pull request #156 from MaKaNu/master
Browse files Browse the repository at this point in the history
Fix the model test script
  • Loading branch information
Tramac authored Nov 2, 2020
2 parents 82db70c + 277a5b8 commit c77d349
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
16 changes: 15 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,25 @@ eval/
# overfitting test

# run result
/tests/runs
/runs

# model
/models/hrnet.py
/models/psanet_old.py
/scripts/debug.py

# nn
nn/sync_bn/
nn/sync_bn/

# venv
AwsmSemSegPytorch-env/
.vscode/launch.json
.vscode/settings.json


# builded files
core/nn/sync_bn/lib/gpu/build.ninja
core/nn/sync_bn/lib/gpu/.ninja_log
core/nn/sync_bn/lib/gpu/.ninja_deps

31 changes: 21 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
import argparse
import time
import os
import sys

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np

cur_path = os.path.abspath(os.path.dirname(__file__))
root_path = os.path.split(cur_path)[0]
sys.path.append(root_path)

from torchvision import transforms
from core.models.model_zoo import get_segmentation_model
from core.utils.loss import MixSoftmaxCrossEntropyLoss, EncNetLoss, ICNetLoss
Expand All @@ -20,17 +25,21 @@
def parse_args():
parser = argparse.ArgumentParser(description='Semantic Segmentation Overfitting Test')
# model
parser.add_argument('--model', type=str, default='ocnet',
choices=['fcn32s/fcn16s/fcn8s/fcn/psp/deeplabv3/danet/denseaspp/bisenet/encnet/dunet/icnet/enet/ocnet'],
parser.add_argument('--model', type=str, default='fcn32s',
choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn', 'psp',
'deeplabv3', 'danet', 'denseaspp', 'bisenet', 'encnet',
'dunet', 'icnet', 'enet', 'ocnet'],
help='model name (default: fcn32s)')
parser.add_argument('--backbone', type=str, default='resnet50',
choices=['vgg16/resnet18/resnet50/resnet101/resnet152/densenet121/161/169/201'],
parser.add_argument('--backbone', type=str, default='vgg16',
choices=['vgg16', 'resnet18', 'resnet50', 'resnet101',
'resnet152', 'densenet121', '161', '169', '201'],
help='backbone name (default: vgg16)')
parser.add_argument('--dataset', type=str, default='pascal_voc',
choices=['pascal_voc/pascal_aug/ade20k/citys/sbu'],
choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys',
'sbu'],
help='dataset name (default: pascal_voc)')
parser.add_argument('--epochs', type=int, default=100, metavar='N',
help='number of epochs to train (default: 60)')
help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
Expand Down Expand Up @@ -106,7 +115,9 @@ def train(self):
self.model.train()
start_time = time.time()
for epoch in range(self.args.epochs):
cur_lr = self.lr_scheduler(epoch)
self.lr_scheduler(self.optimizer, epoch)
cur_lr = self.lr_scheduler.learning_rate
# self.lr_scheduler(self.optimizer, epoch)
for param_group in self.optimizer.param_groups:
param_group['lr'] = cur_lr

Expand All @@ -117,17 +128,17 @@ def train(self):
loss = self.criterion(outputs, targets)

self.optimizer.zero_grad()
loss.backward()
loss['loss'].backward()
self.optimizer.step()

pred = torch.argmax(outputs[0], 1).cpu().data.numpy()
mask = get_color_pallete(pred.squeeze(0), self.args.dataset)
save_pred(self.args, epoch, mask)
hist, labeled, correct = hist_info(pred, targets.numpy(), 21)
hist, labeled, correct = hist_info(pred, targets.cpu().numpy(), 21)
_, mIoU, _, pixAcc = compute_score(hist, correct, labeled)

print('Epoch: [%2d/%2d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f || pixAcc: %.3f || mIoU: %.3f' % (
epoch, self.args.epochs, time.time() - start_time, cur_lr, loss.item(), pixAcc, mIoU))
epoch, self.args.epochs, time.time() - start_time, cur_lr, loss['loss'].item(), pixAcc, mIoU))


def save_pred(args, epoch, mask):
Expand Down

0 comments on commit c77d349

Please sign in to comment.