From 0df563e61fe8eb163918e944da75e84ccd5c07d4 Mon Sep 17 00:00:00 2001 From: dbuscombe-usgs Date: Tue, 26 Nov 2024 15:40:34 -0800 Subject: [PATCH] update do_seg --- doodleverse_utils/prediction_imports.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/doodleverse_utils/prediction_imports.py b/doodleverse_utils/prediction_imports.py index d6e2867..86817d8 100755 --- a/doodleverse_utils/prediction_imports.py +++ b/doodleverse_utils/prediction_imports.py @@ -275,7 +275,13 @@ def est_label_binary(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE,w,h): est_label = est_label + est_label2 + est_label3 + est_label4 # del est_label2, est_label3, est_label4 - est_label = est_label.numpy().astype('float32') + # est_label = est_label.numpy().astype('float32') + + if not isinstance(est_label, np.ndarray): + # If not, convert it to a numpy array + est_label = est_label.numpy() + # Now, convert to 'float32' + est_label = est_label.astype('float32') if MODEL=='segformer': est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze() @@ -396,7 +402,13 @@ def do_seg( est_label /= counter + 1 # est_label cannot be float16 so convert to float32 - est_label = est_label.numpy().astype('float32') + # est_label = est_label.numpy().astype('float32') + + if not isinstance(est_label, np.ndarray): + # If not, convert it to a numpy array + est_label = est_label.numpy() + # Now, convert to 'float32' + est_label = est_label.astype('float32') if MODEL=='segformer': est_label = resize(est_label, (1, NCLASSES, TARGET_SIZE[0],TARGET_SIZE[1]), preserve_range=True, clip=True).squeeze()