Skip to content

Commit

Permalink
massive refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
joewong00 committed Apr 8, 2022
1 parent df14221 commit 4a094b6
Show file tree
Hide file tree
Showing 15 changed files with 841 additions and 1,020 deletions.
Binary file added Training info.xlsx
Binary file not shown.
456 changes: 0 additions & 456 deletions dataloading.ipynb

This file was deleted.

Binary file added images/Training Loss36.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
639 changes: 639 additions & 0 deletions main.ipynb

Large diffs are not rendered by default.

Binary file modified output/Mask_MRI1_T2.nii.gz
Binary file not shown.
Binary file modified output/image.pdf
Binary file not shown.
Binary file modified output/result.pdf
Binary file not shown.
31 changes: 2 additions & 29 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,9 @@
from residual3dunet.res3dunetmodel import ResidualUNet3D
from torch.nn import DataParallel
from utils.segmentation_statistics import SegmentationStatistics
from utils.utils import load_checkpoint, read_data_as_numpy, add_channel, to_depth_first, numpy_to_nii, visualize2d, to_depth_last, plot_overlapped, preprocess
from utils.utils import load_checkpoint, read_data_as_numpy, numpy_to_nii, visualize2d, plot_sidebyside, plot_overlapped, preprocess, predict


def predict(model,input,threshold,device):

model.eval()

input = to_depth_first(input)

if len(input.shape) == 3:
input = add_channel(input)

# Add batch dimension
input = input.unsqueeze(0)
input = input.to(device=device, dtype=torch.float32)

# Disable grad
with torch.no_grad():

output = model(input)
preds = (output > threshold).float()

# Squeeze channel and batch dimension
preds = torch.squeeze(preds)

# Convert to numpy
preds = preds.cpu().numpy()

return preds

def get_args():
# Test settings
parser = argparse.ArgumentParser(description='Predict masks from input images')
Expand Down Expand Up @@ -76,7 +49,6 @@ def main():
model = DataParallel(model)

load_checkpoint(args.model, model ,device=device)
# model.load_state_dict(torch.load(args.model, map_location=device))

logging.info('Model loaded!')
logging.info(f'\nPredicting image {filename} ...')
Expand Down Expand Up @@ -104,6 +76,7 @@ def main():
target = preprocess(read_data_as_numpy(args.mask),rotate=True, to_tensor=False)

plot_overlapped(data, prediction, target)
plot_sidebyside(data, prediction, target)

prediction = prediction.astype(bool)
target = target.astype(bool)
Expand Down
66 changes: 66 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import argparse
import matplotlib.pyplot as plt
from dataloader import MRIDataset
from residual3dunet.model import ResidualUNet3D, UNet3D
from torch.utils.data import Dataset, DataLoader
from torch.nn import DataParallel
import numpy as np
import torch
import torchvision.transforms as T
from utils.segmentation_statistics import SegmentationStatistics
from utils.evaluate import evaluate
from utils.utils import compute_average, load_checkpoint
import os


def get_args():
# Test settings
parser = argparse.ArgumentParser(description='Evaluate using test loader')
parser.add_argument('--network', '-u', default='Unet3D', help='Specify the network (Unet3D / ResidualUnet3D)')
parser.add_argument('--model', '-m', default='model.pt', metavar='FILE', help='Specify the file in which the model is stored')
parser.add_argument('--batch-size', type=int, default=1, metavar='N',help='input batch size for testing (default: 64)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA testing')
parser.add_argument('--mask-threshold', '-t', type=float, default=0.5, help='Minimum probability value to consider a mask pixel white')

return parser.parse_args()


def main():

args = get_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

model = ResidualUNet3D(in_channels=1, out_channels=1, testing=True).to(device)

# If using multiple gpu
# Specify network
if args.network.casefold() == "unet3d":
model = UNet3D(in_channels=1, out_channels=1, testing=True).to(device)

else:
model = ResidualUNet3D(in_channels=1, out_channels=1, testing=True).to(device)

# If using multiple gpu
if torch.cuda.device_count() > 1 and use_cuda:
model = DataParallel(model)

load_checkpoint(args.model, model ,device=device)

test_kwargs = {'batch_size': args.batch_size}

if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
test_kwargs.update(cuda_kwargs)

testdataset = MRIDataset(train=False, transform=T.ToTensor())
test_loader = DataLoader(dataset=testdataset, **test_kwargs)

evaluate(model, test_loader, device, args.mask_threshold, show_stat=True)


if __name__ == '__main__':
main()
47 changes: 25 additions & 22 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from torch.optim import Adam
from dataloader import MRIDataset
from evaluate import evaluate
from residual3dunet.model import ResidualUNet3D, UNet3D
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import StepLR
Expand All @@ -20,10 +19,10 @@ def train(args, model, device, train_loader, optimizer, epoch, criterion):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):

assert data.shape[1] == model.in_channels, \
f'Network has been defined with {model.n_channels} input channels, ' \
f'but loaded images have {data.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
# assert data.shape[1] == model.in_channels, \
# f'Network has been defined with {model.n_channels} input channels, ' \
# f'but loaded images have {data.shape[1]} channels. Please check that ' \
# 'the images are loaded correctly.'

data, target = data.float().to(device), target.float().to(device)

Expand Down Expand Up @@ -74,7 +73,6 @@ def get_args():
parser = argparse.ArgumentParser(description='PyTorch 3D Segmentation')
parser.add_argument('--network', '-u', default='Unet3D', help='Specify the network (Unet3D / ResidualUnet3D)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=3, metavar='N',help='number of epochs to train (default: 14)')
parser.add_argument('--lr', type=float, default=2.5e-4, metavar='LR', help='learning rate (default: 1.0)')
parser.add_argument('--gamma', type=float, default=0.1, metavar='M',help='Learning rate step gamma (default: 0.7)')
Expand All @@ -90,6 +88,8 @@ def get_args():

def main():

# ------------------------------------ Network Config ------------------------------------

args = get_args()

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
Expand All @@ -112,7 +112,6 @@ def main():
# Specify network
if args.network.casefold() == "unet3d":
model = UNet3D(in_channels=1, out_channels=1).to(device)

else:
model = ResidualUNet3D(in_channels=1, out_channels=1).to(device)

Expand All @@ -124,37 +123,40 @@ def main():
if args.checkpoint:
load_checkpoint(args.checkpoint, model, device=device)

logging.info(f'Network:\n'
f'\t{model.in_channels} input channels\n'
f'\t{model.out_channels} output channels (classes)\n')
# Hyperparameters
optimizer = Adam(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=50, gamma=args.gamma)
loss = DiceBCELoss()


# logging.info(f'Network:\n'
# f'\t{model.in_channels} input channels\n'
# f'\t{model.out_channels} output channels (classes)\n')

# Data Loading

# ------------------------------------ Data Loading ------------------------------------

# Train data transformation
transformation = T.Compose([T.ToTensor(),
T.RandomHorizontalFlip(),
T.RandomRotation(90),
T.RandomCrop((240,240), padding=50, pad_if_needed=True)
])


traindataset = MRIDataset(train=True, transform=transformation, elastic=True)
# testdataset = MRIDataset(train=False, transform=T.ToTensor())

# Train validation set splitting 90/10
train_set, val_set = random_split(traindataset, [int(len(traindataset)*0.9),int(len(traindataset)*0.1)])

train_loader = DataLoader(dataset=train_set, **train_kwargs)
val_loader = DataLoader(dataset=val_set, **train_kwargs)
# test_loader = DataLoader(dataset=testdataset, **test_kwargs)

# Hyperparameters
optimizer = Adam(model.parameters(), lr=args.lr)
scheduler = StepLR(optimizer, step_size=50, gamma=args.gamma)
loss = DiceBCELoss()
# ------------------------------------ Training Loop ------------------------------------

# Validation Loss
minvalidation = 1
loss_train = []
loss_val = []
min_dice = 1

logging.info(f'''Starting training:
Network: {args.network}
Expand All @@ -169,9 +171,8 @@ def main():
# Training process
for epoch in range(1, args.epochs + 1):

# Training
trainloss = train(args, model, device, train_loader, optimizer, epoch, loss)
valloss = evaluate(model, val_loader, device, loss)
valloss = test(model, device, val_loader, epoch, loss)

print('Average train loss: {}'.format(trainloss))
print('Average test loss: {}'.format(valloss))
Expand All @@ -181,12 +182,14 @@ def main():

scheduler.step()

# Save the best validated model
if valloss < minvalidation and args.save_model:
minvalidation = valloss

save_model(model, is_best=True, checkpoint_dir='checkpoints')

# plot_train_loss(loss_train, loss_val)
# Plot training loss graph
plot_train_loss(loss_train, loss_val)


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion utils/elastic_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
from utils import convert_to_numpy, add_channel, to_depth_last
from utils.utils import convert_to_numpy, add_channel, to_depth_last

import numbers
import numpy as np
Expand Down
15 changes: 12 additions & 3 deletions evaluate.py → utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from utils.utils import compute_average


def evaluate(net, dataloader, device, show_stat=False):
def evaluate(net, dataloader, device, threshold, show_stat=False):
"""Evaluate the model using test data using different evaluation metrics (check utils/segmentation_statistics.py)
Args:
net (torch.nn.Module): Trained model
dataloader (DataLoader): Test data loader
device (torch.device): Device (cpu or cuda)
show stat (bool): Show the statistical result based on dataset cohort
Returns:
stats (dict): the average evaluation metrics
"""

stats = []
net.eval()
Expand All @@ -17,7 +26,7 @@ def evaluate(net, dataloader, device, show_stat=False):
data, target = data.float().to(device), target.float().to(device)
output = net(data)

preds = (F.sigmoid(output) > 0.5).float()
preds = (output > threshold).float()

# Convert to numpy boolean
preds = preds.cpu().numpy()
Expand All @@ -27,7 +36,7 @@ def evaluate(net, dataloader, device, show_stat=False):

batch, channel, depth, width, height = preds.shape

for idx in len(batch):
for idx in range(batch):
stat = SegmentationStatistics(preds[idx,0,:,:,:], target[idx,0,:,:,:], (3,2,1))
stats.append(stat.to_dict())

Expand Down
2 changes: 1 addition & 1 deletion utils/transform_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import torchvision.transforms as T

from elastic_transform import RandomElastic
from utils.elastic_transform import RandomElastic


class Transform_3D_Mask_Label(object):
Expand Down
Loading

0 comments on commit 4a094b6

Please sign in to comment.