Skip to content

Commit

Permalink
Merge pull request #50 from roclark/support-multi-gpu-eval
Browse files Browse the repository at this point in the history
Support multi-GPU evaluation
  • Loading branch information
Tramac authored Jul 29, 2019
2 parents 2f0d6ee + a98dcec commit 7fbe397
Show file tree
Hide file tree
Showing 21 changed files with 74 additions and 24 deletions.
4 changes: 3 additions & 1 deletion core/models/bisenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def get_bisenet(dataset='citys', backbone='resnet18', pretrained=False, root='~/
model = BiSeNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('bisenet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('bisenet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/ccnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def get_ccnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root=
model = CCNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('ccnet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('ccnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/cgnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,9 @@ def get_cgnet(dataset='citys', backbone='', pretrained=False, root='~/.torch/mod
model = CGNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('cgnet_%s' % (acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('cgnet_%s' % (acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/danet.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def get_danet(dataset='citys', backbone='resnet50', pretrained=False,
model = DANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('danet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('danet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def get_deeplabv3(dataset='pascal_voc', backbone='resnet50', pretrained=False, r
model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/deeplabv3_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=Fal
model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
device = torch.device(kwargs['local_rank'])
model.load_state_dict(
torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root)))
torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/denseaspp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def get_denseaspp(dataset='citys', backbone='densenet121', pretrained=False,
model = DenseASPP(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('denseaspp_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/dfanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/mo
model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/dunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/encnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root
model = EncNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('encnet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('encnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/enet.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ def get_enet(dataset='citys', backbone='', pretrained=False, root='~/.torch/mode
model = ENet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('enet_%s' % (acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('enet_%s' % (acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/espnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def get_espnet(dataset='pascal_voc', backbone='', pretrained=False, root='~/.tor
model = ESPNetV2(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('espnet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('espnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
12 changes: 9 additions & 3 deletions core/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def get_fcn32s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~
model = FCN32s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('fcn32s_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand All @@ -178,7 +180,9 @@ def get_fcn16s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~
model = FCN16s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('fcn16s_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand All @@ -195,7 +199,9 @@ def get_fcn8s(dataset='pascal_voc', backbone='vgg16', pretrained=False, root='~/
model = FCN8s(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('fcn8s_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/fcnv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~
model = FCN(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/icnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.t
model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/lednet.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def get_lednet(dataset='citys', backbone='', pretrained=False, root='~/.torch/mo
model = LEDNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('lednet_%s' % (acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('lednet_%s' % (acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/ocnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,10 @@ def get_ocnet(dataset='citys', backbone='resnet50', oc_arch='base', pretrained=F
pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('%s_ocnet_%s_%s' % (
oc_arch, backbone, acronyms[dataset]), root=root)))
oc_arch, backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/psanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root
model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/psanet_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root
model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
4 changes: 3 additions & 1 deletion core/models/pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~
model = PSPNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
if pretrained:
from .model_store import get_model_file
model.load_state_dict(torch.load(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root)))
device = torch.device(kwargs['local_rank'])
model.load_state_dict(torch.load(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root),
map_location=device))
return model


Expand Down
10 changes: 8 additions & 2 deletions scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
sys.path.append(root_path)

import torch
import torch.nn as nn
import torch.utils.data as data
import torch.backends.cudnn as cudnn

Expand Down Expand Up @@ -43,10 +44,14 @@ def __init__(self, args):
pin_memory=True)

# create network
BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
aux=args.aux, pretrained=True, pretrained_base=False)
aux=args.aux, pretrained=True, pretrained_base=False,
local_rank=args.local_rank,
norm_layer=BatchNorm2d).to(self.device)
if args.distributed:
self.model = self.model.module
self.model = nn.parallel.DistributedDataParallel(self.model,
device_ids=[args.local_rank], output_device=args.local_rank)
self.model.to(self.device)

self.metric = SegmentationMetric(val_dataset.num_class)
Expand Down Expand Up @@ -107,3 +112,4 @@ def eval(self):

evaluator = Evaluator(args)
evaluator.eval()
torch.cuda.empty_cache()

0 comments on commit 7fbe397

Please sign in to comment.