From fb81e3b5392f109c863673bb5bc49889fdc812a9 Mon Sep 17 00:00:00 2001 From: iback Date: Mon, 9 Sep 2024 08:25:06 +0000 Subject: [PATCH] upgrades 2nd phase to bigger and channel-wise model. More robust to abberations. --- spineps/Unet3D/pl_unet.py | 3 ++- spineps/seg_model.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/spineps/Unet3D/pl_unet.py b/spineps/Unet3D/pl_unet.py index e49e96e..ed03979 100755 --- a/spineps/Unet3D/pl_unet.py +++ b/spineps/Unet3D/pl_unet.py @@ -16,7 +16,7 @@ def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> N nclass = Unet3D dim_mults = (1, 2, 4, 8) - dim = 8 + dim = 16 # 16 # if opt.high_res: # dim = 16 @@ -26,6 +26,7 @@ def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> N dim=dim, dim_mults=dim_mults, out_dim=4, + channels=10, # 10, ) self.opt = opt diff --git a/spineps/seg_model.py b/spineps/seg_model.py index 4abd866..fdb3a39 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.nn.functional as F # noqa: N812 from torch import from_numpy from TPTBox import NII, ZOOMS, Image_Reference, Log_Type, Logger, No_Logger, to_nii from typing_extensions import Self @@ -340,11 +341,31 @@ def run( input_nii = input_nii[0] arr = input_nii.get_seg_array().astype(np.int16) - target = from_numpy(arr).to(torch.float32) - target /= 9 - target = target.unsqueeze(0) - target = target.unsqueeze(0) - logits = self.predictor.forward(target.to(self.device)) + target = from_numpy(arr) + + target[target == 26] = 0 + + do_backup = False + # channel-wise + try: + targetc = target.to(torch.int64) + targetc = F.one_hot(targetc, num_classes=10) + targetc = targetc.permute(3, 0, 1, 2) + targetc = targetc.unsqueeze(0) + targetc = targetc.to(torch.float32) + logits = self.predictor.forward(targetc.to(self.device)) + # + except Exception: + # print("Channel-wise model failed, try legacy version") + do_backup = True + # + if do_backup: + target = target.to(torch.float32) + target /= 9 + target = target.unsqueeze(0) + target = target.unsqueeze(0) + logits = self.predictor.forward(target.to(self.device)) + # pred_x = self.predictor.softmax(logits) _, pred_cls = torch.max(pred_x, 1) del logits