diff --git a/keras_unet_collection/_model_unet_3plus_2d.py b/keras_unet_collection/_model_unet_3plus_2d.py index f956a7f..af018ad 100644 --- a/keras_unet_collection/_model_unet_3plus_2d.py +++ b/keras_unet_collection/_model_unet_3plus_2d.py @@ -241,10 +241,10 @@ def unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', X_decoder = unet_3plus_2d_backbone(IN, filter_num_down, filter_num_skip, filter_num_aggregate, stack_num_down=stack_num_down, stack_num_up=stack_num_up, activation=activation, batch_norm=batch_norm, pool=pool, unpool=unpool, name=name) + X_decoder = X_decoder[::-1] + if deep_supervision: - OUT_stack = [] - X_decoder = X_decoder[::-1] L_out = len(X_decoder) print('----------\ndeep_supervision = True\nnames of output tensors are listed as follows (the last one is the final output):') @@ -279,7 +279,7 @@ def unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', OUT_stack.append(X) OUT_stack.append( - CONV_output(X_decoder[-1], n_labels, kernel_size=3, + CONV_output(X_decoder[0], n_labels, kernel_size=3, activation=activation, name='{}_output_final'.format(name))) if output_activation: print('\t{}_output_final_activation'.format(name)) @@ -289,7 +289,7 @@ def unet_3plus_2d(input_size, n_labels, filter_num_down, filter_num_skip='auto', model = Model([IN,], OUT_stack) else: - OUT = CONV_output(X_decoder[-1], n_labels, kernel_size=3, + OUT = CONV_output(X_decoder[0], n_labels, kernel_size=3, activation=activation, name='{}_output_final'.format(name)) model = Model([IN,], [OUT,]) diff --git a/setup.py b/setup.py index 6d69156..d2e7d0e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name = "keras-unet-collection", - version = "0.0.7", + version = "0.0.7beta", author = "Yingkai (Kyle) Sha", author_email = "yingkaisha@gmail.com", description = "The Tensorflow, Keras implementation of U-net, U-net++, R2U-net, Attention U-net, ResUnet-a, U^2-Net, and UNET 3+.", diff --git a/user_guid.ipynb b/user_guid.ipynb index 73fdf22..af11a07 100644 --- a/user_guid.ipynb +++ b/user_guid.ipynb @@ -283,7 +283,7 @@ } ], "source": [ - "unet3plus = models.unet_3plus_2d((None, None, 3), n_labels=2, filter_num_down=[64, 128, 256, 512], \n", + "unet3plus = models.unet_3plus_2d((512, 512, 3), n_labels=2, filter_num_down=[64, 128, 256, 512], \n", " filter_num_skip=[64, 64, 64], filter_num_aggregate=256, \n", " stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',\n", " batch_norm=False, pool=True, unpool=False, deep_supervision=True, name='unet3plus')"