当前位置:   article > 正文

cycle GAN

cycle GAN

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'#设置tensorflow的日志级别
from tensorflow.python.platform import build_info

import tensorflow as tf

# 列出所有物理GPU设备  
gpus = tf.config.list_physical_devices('GPU')  
if gpus:  
    # 如果有GPU,设置GPU资源使用率  
    try:  
        # 允许GPU内存按需增长  
        for gpu in gpus:  
            tf.config.experimental.set_memory_growth(gpu, True)  
        # 设置可见的GPU设备(这里实际上不需要,因为已经通过内存增长设置了每个GPU)  
        # tf.config.set_visible_devices(gpus, 'GPU')  
        print("GPU可用并已设置内存增长模式。")  
    except RuntimeError as e:  
        # 虚拟设备未就绪时可能无法设置GPU  
        print(f"设置GPU时发生错误: {e}")  
else:  
    # 如果没有GPU  
    print("没有检测到GPU设备。")

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
AUTOTUNE = tf.data.AUTOTUNE
# tf.data.AUTOTUNE 是一个特殊的值,它告诉TensorFlow的tf.data API自动选择适当的并行度。
# 当使用tf.data API来构建输入管道时,经常需要决定并行
# 处理数据的方式,以最大化数据加载和预处理的速度,同时不浪费计算资源。

# 加载训练数据  
def load_and_preprocess_image(image_path):  
    image = tf.io.read_file(image_path)  
    image = tf.image.decode_jpeg(image, channels=3)  
    image = tf.image.resize(image, IMAGE_SIZE)  
    image /= 255.0  # 归一化到[0, 1]  
    return image 

BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256

#改变图片大小
def resize(image, height, width):
  image = tf.image.resize(image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  return image

#定义随机裁剪方法
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])
  return cropped_image

# 标准化 to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image

def random_jitter(image):
  # 改变尺寸到 286x286
  image = resize(image, 286, 286)
  # 随机裁剪to 256 x 256 x 3
  image = random_crop(image)
  # 随机的水平翻转
  image = tf.image.random_flip_left_right(image)
  return image

def load(image_file):
    # 读取图片文件,并且解码转换成uint8
    image = tf.io.read_file(image_file)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.cast(image, tf.float32)
    return image

def preprocess_image_train(image_file):#定义预处理训练图片的方法
    # print(image_file)
    image = load(image_file)
    image = random_jitter(image)
    image = normalize(image)
    return image

import matpl

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/400125
推荐阅读
相关标签
  

闽ICP备14008679号