Skip to content

Commit

Permalink
Bugfix: input array replaced to a plain tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanKuchin committed Aug 12, 2024
1 parent 07dd305 commit 4bc65d8
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions tools/craft_network/att_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def res_block(filters, input_shape, kernel_size, apply_batchnorm, apply_instance
if (apply_dropout):
x = tf.keras.layers.Dropout(0.5)(x)

return tf.keras.models.Model(inputs = [input_layer], outputs = x)
return tf.keras.models.Model(inputs = input_layer, outputs = x)

def double_conv(filters, input_shape, kernel_size, apply_batchnorm, apply_instancenorm, apply_dropout=False):
model = tf.keras.models.Sequential()
Expand Down Expand Up @@ -72,12 +72,7 @@ def double_conv(filters, input_shape, kernel_size, apply_batchnorm, apply_instan

def get_gating_base(filters, apply_batchnorm = True):
__model = tf.keras.models.Sequential(name = "gating_base")

for _ in range(2):
__model.add(tf.keras.layers.Conv3D(filters, kernel_size = 1, padding = "same", kernel_initializer='he_uniform'))
if apply_batchnorm:
__model.add(tf.keras.layers.BatchNormalization(momentum = config.BATCH_NORM_MOMENTUM))
__model.add(tf.keras.layers.LeakyReLU())
__model.add(tf.keras.layers.Conv3D(filters, kernel_size = 1, padding = "same", kernel_initializer='he_uniform'))

return __model

Expand Down

0 comments on commit 4bc65d8

Please sign in to comment.