当前位置:   article > 正文

openai/gym中的图像预处理_opengait代码详解

opengait代码详解

openai/gym中的图像预处理

之前读强化学习的文章,提到对gym模拟器的图像进行预处理,每4张图片生成一个(84,84,4)的tensor,但是在网上没有搜到具体的实现,因此写一个预处理的函数,用到了cv2,这个模块需要安装opencv-python这个库

步骤如下:

  1. resize:4个图片,每个图片的shape为(210,160,3),resize为(84,84,3)
  2. 灰化:由(84,84,3)变为(84,84,1)
  3. 归一化:数据类型从[0,255]的uint8,变为[0.0,1.0]的float32
  4. concat:将4个图片连接为1个(84,84,4)的tensor

代码如下:

import matplotlib.pyplot as plt
import gym
import numpy as np
import cv2

# 输入 N个3通道的图片array
# 输出:一个array 形状 (84 84 N)
# 步骤: 1. resize ==>(84 84 3)[uint 0-255]
#       2. gray   ==> (84 84 1) [uint 0-255]
#       3. norm   ==> (84 84 1) [float32 0.0-1.0]
#       4. concat ===>(84 84 N) [float32 0.0-1.0]
def imgbuffer_process(imgbuffer, out_shape = (84, 84)):
    img_list = []
    for img in imgbuffer:
        tmp = cv2.resize(src=img, dsize=out_shape)
        tmp = cv2.cvtColor(tmp, cv2.COLOR_BGR2GRAY)
        ## 需要将数据类型转为32F
        tmp = cv2.normalize(tmp, tmp, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        # 扩充一个维度
        tmp = np.expand_dims(tmp, len(tmp.shape))
        img_list.append(tmp)
    ret =  np.concatenate(tuple(img_list), axis=2)
    #print('ret_shape = ' + str(ret.shape))
    return ret

def test():
    env = gym.make('Breakout-v4')
    env.seed(1)  # reproducible
    # env = env.unwrapped
    N_F = env.observation_space.shape[0]  # 状态空间的维度
    N_A = env.action_space.n  # 动作空间的维度

    img_buffer = []
    img_buffer_size = 4
    s = env.reset()
    max_loop = 100000

    for i in range(max_loop):
        a = np.random.randint(0, N_A - 1)
        s_, r, done, info = env.step(a)
        env.render()

        if len(img_buffer) < img_buffer_size:
            img_buffer.append(s_)
            continue
        else:
            img_buffer.pop(0)
            img_buffer.append(s_)

        img_input = imgbuffer_process(img_buffer)
        print('img_input_shape = ' + str(img_input.shape))
        plt.subplot(2, 2, 1)
        plt.imshow(np.uint8(img_input[:, :, 0] * 255), cmap='gray')
        plt.subplot(2, 2, 2)
        plt.imshow(np.uint8(img_input[:, :, 1] * 255), cmap='gray')
        plt.subplot(2, 2, 3)
        plt.imshow(np.uint8(img_input[:, :, 2] * 255), cmap='gray')
        plt.subplot(2, 2, 4)
        plt.imshow(np.uint8(img_input[:, :, 3] * 255), cmap='gray')
        plt.show()

if __name__ == '__main__':
    test()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

运行截图

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号