Skip to content

Commit

Permalink
Merge pull request #29 from Hendrik-code/2ndphase
Browse files Browse the repository at this point in the history
upgrades 2nd phase to bigger and channel-wise model. More robust to abberations
  • Loading branch information
Hendrik-code authored Sep 9, 2024
2 parents 5a56cbc + 9f211c8 commit 002245a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
3 changes: 2 additions & 1 deletion spineps/Unet3D/pl_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> N
nclass = Unet3D

dim_mults = (1, 2, 4, 8)
dim = 8
dim = 16 # 16

# if opt.high_res:
# dim = 16
Expand All @@ -26,6 +26,7 @@ def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> N
dim=dim,
dim_mults=dim_mults,
out_dim=4,
channels=10, # 10,
)

self.opt = opt
Expand Down
31 changes: 26 additions & 5 deletions spineps/seg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
from torch import from_numpy
from TPTBox import NII, ZOOMS, Image_Reference, Log_Type, Logger, No_Logger, to_nii
from typing_extensions import Self
Expand Down Expand Up @@ -340,11 +341,31 @@ def run(
input_nii = input_nii[0]

arr = input_nii.get_seg_array().astype(np.int16)
target = from_numpy(arr).to(torch.float32)
target /= 9
target = target.unsqueeze(0)
target = target.unsqueeze(0)
logits = self.predictor.forward(target.to(self.device))
target = from_numpy(arr)

target[target == 26] = 0

do_backup = False
# channel-wise
try:
targetc = target.to(torch.int64)
targetc = F.one_hot(targetc, num_classes=10)
targetc = targetc.permute(3, 0, 1, 2)
targetc = targetc.unsqueeze(0)
targetc = targetc.to(torch.float32)
logits = self.predictor.forward(targetc.to(self.device))
#
except Exception:
# print("Channel-wise model failed, try legacy version")
do_backup = True
#
if do_backup:
target = target.to(torch.float32)
target /= 9
target = target.unsqueeze(0)
target = target.unsqueeze(0)
logits = self.predictor.forward(target.to(self.device))
#
pred_x = self.predictor.softmax(logits)
_, pred_cls = torch.max(pred_x, 1)
del logits
Expand Down

0 comments on commit 002245a

Please sign in to comment.