赞
踩
TF2 模型代码：
- import tensorflow.keras.backend as K
- from tensorflow.keras.models import Model
- from tensorflow.keras import Input
- from tensorflow.keras.layers import Conv2D, PReLU, UpSampling2D, concatenate , Reshape, Dense, Permute, MaxPool2D
- from tensorflow.keras.layers import GlobalAveragePooling2D, Activation, add, GaussianNoise, BatchNormalization, multiply
- from tensorflow.keras.optimizers import SGD
- from loss import custom_loss
- K.set_image_data_format("channels_last")
-
-
-
- def unet_model(input_shape, modified_unet=True, learning_rate=0.01, start_channel=64,
- number_of_levels=3, inc_rate=2, output_channels=4, saved_model_dir=None):
- """
- Builds UNet model
-
- Parameters
- ----------
- input_shape : tuple
- Shape of the input data (height, width, channel)
- modified_unet : bool
- Whether to use modified UNet or the original UNet
- learning_rate : float
- Learning rat
Copyright © 2003-2013 www.wpsshop.cn 版权所有，并保留所有权利。