diff --git a/python/train.py b/python/train.py index f379bc3..3b59994 100644 --- a/python/train.py +++ b/python/train.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from fcn import VGGNet, FCN32s, FCN16s, FCN8s, FCNs -from Cityscapes_loader import CityscapesDataset +from Cityscapes_loader import CityScapesDataset from CamVid_loader import CamVidDataset from matplotlib import pyplot as plt @@ -34,7 +34,7 @@ if sys.argv[1] == 'CamVid': root_dir = "CamVid/" -else +else: root_dir = "CityScapes/" train_file = os.path.join(root_dir, "train.csv") val_file = os.path.join(root_dir, "val.csv") @@ -57,7 +57,7 @@ if sys.argv[1] == 'CamVid': val_data = CamVidDataset(csv_file=val_file, phase='val', flip_rate=0) else: - val_data = CityscapesDataset(csv_file=val_file, phase='val', flip_rate=0) + val_data = CityScapesDataset(csv_file=val_file, phase='val', flip_rate=0) val_loader = DataLoader(val_data, batch_size=1, num_workers=8) vgg_model = VGGNet(requires_grad=True, remove_fc=True)