当前位置:   article > 正文

强化学习玩flappy_bird_pygame强化学习

pygame强化学习

强化学习玩flappy_bird(代码解析)

游戏地址:https://flappybird.io/

该游戏的规则是:

  • 点击屏幕则小鸟立即获得向上速度。

  • 不点击屏幕则小鸟受重力加速度影响逐渐掉落。

  • 小鸟碰到地面会死亡,碰到水管会死亡。(碰到天花板不会死亡)

  • 小鸟通过水管会得分。

    img

    具体的网络结构如图所示,网络架构是拿到游戏状态(每个样本维度是 80 * 80 * 4),然后卷积(输出维度 20 * 20 * 32)、池化(输出 10 * 10 * 32)、卷积(输出 5 * 5 * 64)、卷积(输出 5 * 5 * 64)、reshape(1600)、全连接层(512)、输出层(2)

一、flappy_bird_utils.py

"""
游戏素材加载
"""
import pygame
import sys
import os

assets_dir = os.path.dirname(__file__)

def load():
    # 小鸟挥动翅膀的3个造型
    PLAYER_PATH = (
        assets_dir + '/assets/sprites/redbird-upflap.png',
        assets_dir + '/assets/sprites/redbird-midflap.png',
        assets_dir + '/assets/sprites/redbird-downflap.png'
    )

    # 游戏背景图,纯黑色是为了训练降低干扰
    BACKGROUND_PATH = assets_dir + '/assets/sprites/background-black.png'

    # 水管图片
    PIPE_PATH = assets_dir + '/assets/sprites/pipe-green.png'

    IMAGES, SOUNDS, HITMASKS = {}, {}, {}
    #初始化了三个空字典:IMAGES用于存储加载的图片资源,SOUNDS用于存储加载的声音资源,HITMASKS用于存储碰撞掩码(用于检测游戏中的碰撞)

    # 加载数字0~9的图片,类型是Surface图像
    #使用convert_alpha()方法将图片转换为带有透明度的格式
    IMAGES['numbers'] = (
        pygame.image.load(assets_dir + '/assets/sprites/0.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/1.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/2.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/3.png').convert_alpha(),    # convert/conver_alpha是为了将图片转成绘制用的像素格式,提高绘制效率
        pygame.image.load(assets_dir + '/assets/sprites/4.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/5.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/6.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/7.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/8.png').convert_alpha(),
        pygame.image.load(assets_dir + '/assets/sprites/9.png').convert_alpha()
    )

    # 地面图片
    IMAGES['base'] = pygame.image.load(assets_dir + '/assets/sprites/base.png').convert_alpha()


    #根据操作系统类型,设置声音文件的扩展名。Windows系统使用.wav,其他系统使用.ogg
    if 'win' in sys.platform:
        soundExt = '.wav'
    else:
        soundExt = '.ogg'

    # 各种Sound对象
    #加载各种游戏音效,并将它们存储在SOUNDS字典中
    SOUNDS['die']    = pygame.mixer.Sound(assets_dir + '/assets/audio/die' + soundExt)
    SOUNDS['hit']    = pygame.mixer.Sound(assets_dir + '/assets/audio/hit' + soundExt)
    SOUNDS['point']  = pygame.mixer.Sound(assets_dir + '/assets/audio/point' + soundExt)
    SOUNDS['swoosh'] = pygame.mixer.Sound(assets_dir + '/assets/audio/swoosh' + soundExt)
    SOUNDS['wing']   = pygame.mixer.Sound(assets_dir + '/assets/audio/wing' + soundExt)

    # 加载背景图片
    IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()

    # 加载小鸟的3个姿态
    IMAGES['player'] = (
        pygame.image.load(PLAYER_PATH[0]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[1]).convert_alpha(),
        pygame.image.load(PLAYER_PATH[2]).convert_alpha(),
    )

    # 加载水管图片,并使用rotate()方法将其旋转180度以创建上方的水管图片,然后将这两个图片存储在IMAGES字典中
    IMAGES['pipe'] = (
        pygame.transform.rotate(
            pygame.image.load(PIPE_PATH).convert_alpha(), 180),
        pygame.image.load(PIPE_PATH).convert_alpha(),
    )

    # 计算水管图片的bool掩码
    #为水管图片生成碰撞掩码,并将它们存储在HITMASKS字典中。
    HITMASKS['pipe'] = (
        getHitmask(IMAGES['pipe'][0]),
        getHitmask(IMAGES['pipe'][1]),
    )

    # 生成小鸟图片的bool掩码
    HITMASKS['player'] = (
        getHitmask(IMAGES['player'][0]),
        getHitmask(IMAGES['player'][1]),
        getHitmask(IMAGES['player'][2]),
    )

    return IMAGES, SOUNDS, HITMASKS

# 生成图片的bool掩码矩阵,true表示对应像素位置不是透明的部分
def getHitmask(image):
    """returns a hitmask using an image's alpha."""
    mask = []
    for x in range(image.get_width()):#遍历所有的像素点
        mask.append([])
        for y in range(image.get_height()):
            mask[x].append(bool(image.get_at((x,y))[3]))    # 像素点是RGBA,例如:(83, 56, 70, 255),最后是透明度(0是透明,255是不透明)
    #对于图像中的每一个像素点,使用 image.get_at((x,y)) 获取该点的颜色值。颜色值通常以 RGBA(红色、绿色、蓝色、透明度)格式存储,其中 A 代表 Alpha 通道,即透明度。image.get_at((x,y))[3] 就是获取该像素点的 Alpha 值。
    return mask

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103

这里面需要解释的碰撞掩码是什么?

碰撞掩码(Collision Mask)是一种在计算机图形学和游戏开发中用于检测物体间碰撞的技术。它通常由一个布尔矩阵表示,其中每个像素点的值表示该点是否是物体的一部分。在处理碰撞检测时,通过比较两个物体的碰撞掩码可以判断它们是否重叠,从而确定是否发生了碰撞。

以下是碰撞掩码的一些关键点:

  1. 透明度判断:在许多游戏中,碰撞掩码是通过检查图像的透明度(Alpha通道)来生成的。如果图像的某个像素点是不透明的(例如,Alpha值为255),那么在碰撞掩码中对应的位置会被标记为True或实体部分;如果是透明的(Alpha值为0),则被标记为False或非实体部分。

  2. 简化碰撞检测:使用碰撞掩码可以避免直接对图像的每个像素点进行碰撞检测,这样可以显著提高碰撞检测的效率,尤其是在处理复杂图形或大规模场景时。

  3. 灵活性:碰撞掩码可以根据需要设计成不同的形状和大小,从而实现精确的碰撞检测。例如,一个角色的碰撞掩码可以是其轮廓的形状,而不仅仅是一个矩形或正方形。

  4. 性能优化:在游戏开发中,碰撞检测通常是一个计算密集型的过程。通过使用碰撞掩码,可以减少不必要的像素点比较,从而提高游戏性能。

  5. 应用场景:碰撞掩码不仅用于检测角色与障碍物之间的碰撞,还可以用于检测子弹与目标的碰撞、角色间的交互等。

getHitmask函数就是用来生成碰撞掩码的。它通过遍历图像的每个像素点,并检查其透明度来创建一个布尔矩阵。这个矩阵随后可以用于游戏中的碰撞检测逻辑,以判断小鸟是否与水管或其他物体发生了碰撞。

二、wrapped_flappy_bird.py

import numpy as np
import sys
import random
import pygame
from . import flappy_bird_utils
import pygame.surfarray as surfarray#用于将pygame的Surface对象转换为NumPy数组
from pygame.locals import *
from itertools import cycle#用于创建一个可循环的对象

# 屏幕宽*高
FPS = 30
SCREENWIDTH  = 288
SCREENHEIGHT = 512

# 初始化游戏,创建一个时钟对象来控制帧率,设置游戏窗口的尺寸和标题
pygame.init()
FPSCLOCK = pygame.time.Clock()  # FPS限速器
SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))   # 宽*高
pygame.display.set_caption('Flappy Bird')   # 标题

# 加载素材
IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()

PIPEGAPSIZE = 100 # 上下水管之间的距离是固定的100像素
BASEY = SCREENHEIGHT * 0.79 # 地面图片的y坐标

'''
地面图片在游戏窗口中的垂直位置。SCREENHEIGHT是游戏窗口的高度,
乘以0.79后得到一个值,这个值就是地面图片距离窗口顶部的像素距离。
因此,BASEY变量代表了地面图片在Y轴(垂直轴)上的位置,
它被设置在屏幕高度的79%的位置,这样地面图片会显示在屏幕的下半部分。
'''

# 小鸟图片的宽*高
PLAYER_WIDTH = IMAGES['player'][0].get_width()
PLAYER_HEIGHT = IMAGES['player'][0].get_height()
# 水管图片的宽*高
PIPE_WIDTH = IMAGES['pipe'][0].get_width()
PIPE_HEIGHT = IMAGES['pipe'][0].get_height()

# 背景图片的宽
BACKGROUND_WIDTH = IMAGES['background'].get_width()

# 创建一个循环对象,小鸟图片动画播放顺序
PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])

'''
0 表示第一张图片(翅膀上挥)
1 表示第二张图片(翅膀中挥)
2 表示第三张图片(翅膀下挥)
序列最后再次包含1,以实现翅膀的自然循环
'''

# Flappy bird游戏类
class GameState:
    def __init__(self):
        self.score = 0#初始化玩家的得分为 0
        self.playerIndex = 0#初始化玩家小鸟的当前动画索引为 0,这将决定小鸟显示哪一张动画图片
        self.loopIter = 0#初始化一个循环计数器,可能用于跟踪动画或游戏循环的次数

        # 玩家初始坐标
        self.playerx = int(SCREENWIDTH * 0.2)#设置玩家小鸟的初始 x 坐标,位于屏幕宽度的 20% 位置
        self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)#计算并设置玩家小鸟的初始 y 坐标,使得小鸟位于屏幕垂直居中的位置

        # 地面图片需要跑马灯效果,它比屏幕宽一点,每帧向左移动,当要耗尽时重新回到右边,如此往复
        self.basex = 0         # 地面图片的x坐标
        self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH  # 地面图片比屏幕宽度长多少像素,就是它可以移动的距离

        newPipe1 = getRandomPipe()  # 生成一对上下管子
        newPipe2 = getRandomPipe()  # 再生成一对上下管子

        # 上面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.upperPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[0]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[0]['y']},
        ]
        # 下面2根管子,都放到屏幕右侧之外,x相邻半个屏幕距离
        self.lowerPipes = [
            {'x': SCREENWIDTH, 'y': newPipe1[1]['y']},
            {'x': SCREENWIDTH + (SCREENWIDTH / 2), 'y': newPipe2[1]['y']},
        ]

        # 水管的水平移动速度,每次x-4实现向左移动
        self.pipeVelX = -4

        # 小鸟Y方向速度
        self.playerVelY    =  0
        # 小鸟Y方向重力加速度,每帧作用域playerVelY,令其Y速度向下加大
        self.playerAccY    =   1
        # 点击后,小鸟Y方向速度重置为-9,也就是开始向上移动
        self.playerFlapAcc =  -9

        # 小鸟Y方向速度限制
        self.playerMaxVelY =  10   # Y向下最大速度10

    # 执行一次操作,返回操作后的画面、本次操作的奖励(活着+0.1,死了-1,飞过水管+1)、游戏是否结束
    def frame_step(self, input_actions):
        # 给pygame对积累的事件做一下默认处理
        pygame.event.pump()

        # 活着就奖励0.1分
        reward = 0.01
        # 是否死了
        terminal = False

        # 必须传有效的action,[1,0]表示不点击,[0,1]表示点击,全传0是不对的
        if sum(input_actions) != 1:#检查 input_actions 确保只有一个动作被执行
            raise ValueError('Multiple input actions!')

        # 每3帧换一次小鸟造型图片,loopIter统计经过了多少帧
        if (self.loopIter + 1) % 3 == 0:
            self.playerIndex = next(PLAYER_INDEX_GEN)
        self.loopIter += 1

        # 让地面向左移动,游戏开始的时候地面x=0,逐步减小x
        if self.basex + self.pipeVelX <= -self.baseShift:
            self.basex = 0
        else: # 图片即将滚动耗尽,重置x坐标
            self.basex += self.pipeVelX

        # 点击了屏幕
        if input_actions[1] == 1:
            self.playerVelY = self.playerFlapAcc # 将小鸟y方向速度重置为-9,也就是向上移动
            #SOUNDS['wing'].play()   # 播放扇翅膀的声音
        elif self.playerVelY < self.playerMaxVelY:  # 没点击屏幕并且没达到最大掉落速度,继续施加重力加速度
            self.playerVelY += self.playerAccY

        # 将速度施加到小鸟的y坐标上
        self.playery += self.playerVelY
        if self.playery < 0:    # 撞到上边缘不算死
            self.playery = 0 # 限制它别飞出去
        elif self.playery + PLAYER_HEIGHT >= BASEY: # 小鸟碰到地面
            self.playery = BASEY - PLAYER_HEIGHT # 限制它别穿地

        # 让上下水管都向左移动一次
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            uPipe['x'] += self.pipeVelX
            lPipe['x'] += self.pipeVelX

        # 判断小鸟是否穿过了一排水管,因为上下水管x一样,只需要用上排水管判断
        playerMidPos = self.playerx + PLAYER_WIDTH / 2  # 小鸟中心的x坐标(这个是固定值,小鸟实际不会动,是水管在动)
        for pipe in self.upperPipes:    # 检查与上排水管的关系
            pipeMidPos = pipe['x'] + PIPE_WIDTH / 2 # 水管中心的x坐标
            if pipeMidPos <= playerMidPos < pipeMidPos + abs(self.pipeVelX): # 小鸟x坐标刚刚飞过了水管x中心(4是水管的移动速度)
                self.score += 1 # 游戏得分+1
                #SOUNDS['point'].play()
                reward = 100  # 产生强化学习的动作奖励10分

        # 最左侧水管马上离开屏幕,生成新水管
        if 0 < self.upperPipes[0]['x'] < 5:
            newPipe = getRandomPipe()
            self.upperPipes.append(newPipe[0])
            self.lowerPipes.append(newPipe[1])

        # 最左侧水管彻底离开屏幕,删除它的上下2根水管
        if self.upperPipes[0]['x'] < -PIPE_WIDTH:
            self.upperPipes.pop(0)
            self.lowerPipes.pop(0)

        # 检查小鸟是否碰到水管
        isCrash= checkCrash({'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}, self.upperPipes, self.lowerPipes)
        if isCrash:  # 死掉了
            #SOUNDS['hit'].play()
            #SOUNDS['die'].play()
            reward = -10 # 负向激励分
            terminal = True # 本次操作导致游戏结束了

        ##### 进入重绘 #######

        # 贴背景图
        SCREEN.blit(IMAGES['background'], (0,0))
        # 画水管
        for uPipe, lPipe in zip(self.upperPipes, self.lowerPipes):
            SCREEN.blit(IMAGES['pipe'][0], (uPipe['x'], uPipe['y']))
            SCREEN.blit(IMAGES['pipe'][1], (lPipe['x'], lPipe['y']))
        # 画地面
        SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
        # 画得分(训练时候别打开,造成干扰了)
        #showScore(self.score)
        # 画小鸟
        SCREEN.blit(IMAGES['player'][self.playerIndex], (self.playerx, self.playery))
        # 重绘
        pygame.display.update()
        # 留存游戏画面(截图是列优先存储的,需要转行行优先存储)
        # https://stackoverflow.com/questions/34673424/how-to-get-numpy-array-of-rgb-colors-from-pygame-surface
        image_data = pygame.surfarray.array3d(pygame.display.get_surface()).swapaxes(0,1)
        # 死亡则重置游戏状态
        if terminal:
            self.__init__()
        # 控制FPS
        FPSCLOCK.tick(FPS)
        return image_data, reward, terminal

# 生成一对水管,放到屏幕外面
def getRandomPipe():
    gapY = random.randint(70, 140)#生成一个介于 70 到 140 之间的随机整数,并将其赋值给变量 gapY。这个随机数决定了水管之间缝隙的上边缘的 y 坐标

    # 注:每一对水管的缝隙高度都是一样的PIPEGAPSIZE,gayY决定的是缝隙的上边缘y坐标
    pipeX = SCREENWIDTH + 10    # 水管出现在屏幕右侧之外

    return [
        {'x': pipeX, 'y': gapY - PIPE_HEIGHT},  # 计算上面水管图片的y坐标,就是缝隙上边缘y减去水管本身高度
        {'x': pipeX, 'y': gapY + PIPEGAPSIZE},  # 计算下面水管图片的y坐标,就是缝隙上边缘y加上缝隙本身高度
    ]

# 检查小鸟是否碰到水管或者地面(天花板不算)
def checkCrash(player, upperPipes, lowerPipes):
    pi = player['index']    # 小鸟的第几张图片

    # 图片的宽*高
    player['w'] = IMAGES['player'][pi].get_width()
    player['h'] = IMAGES['player'][pi].get_height()

    # 小鸟碰到了地面
    if player['y'] + player['h'] >= BASEY - 1:
        return True
    else: # 小鸟与水管进行碰撞检测
        # 小鸟图片的矩形区域
        playerRect = pygame.Rect(player['x'], player['y'], player['w'], player['h'])

        # 每一对水管
        for uPipe, lPipe in zip(upperPipes, lowerPipes):
            # 上面水管的矩形
            uPipeRect = pygame.Rect(uPipe['x'], uPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)
            # 下面水管的矩形
            lPipeRect = pygame.Rect(lPipe['x'], lPipe['y'], PIPE_WIDTH, PIPE_HEIGHT)

            # 小鸟图片的非透明像素掩码
            pHitMask = HITMASKS['player'][pi]
            # 上水管的非透明像素掩码
            uHitmask = HITMASKS['pipe'][0]
            # 下水管的非透明像素掩码
            lHitmask = HITMASKS['pipe'][1]

            # 检测小鸟与上面水管的碰撞
            uCollide = pixelCollision(playerRect, uPipeRect, pHitMask, uHitmask)
            # 检测小鸟与下面水管的碰撞
            lCollide = pixelCollision(playerRect, lPipeRect, pHitMask, lHitmask)

            if uCollide or lCollide:
                return True
    return False


# 2个矩形区域的碰撞检测
def pixelCollision(rect1, rect2, hitmask1, hitmask2):
    '''
    rect1 和 rect2 是参与碰撞检测的两个矩形区域,通常是游戏中对象的位置和大小
    hitmask1 和 hitmask2 是与这两个矩形关联的碰撞掩码,它们是布尔数组,表示相应对象的哪些部分是实体(非透明)
    '''
    # 计算两个矩形的交集,即它们重叠的区域。如果没有重叠(即两个矩形没有碰撞),则 clip 方法返回一个宽度或高度为 0 的矩形
    rect = rect1.clip(rect2)

    # 相交面积为0
    if rect.width == 0 or rect.height == 0:
        return False

    # 相交矩形x,y相对于2个矩形左上角的距离
    x1, y1 = rect.x - rect1.x, rect.y - rect1.y
    #计算交集区域相对于 rect1 的相对位置
    x2, y2 = rect.x - rect2.x, rect.y - rect2.y#同理

    # 检查相交矩形内的每个点,是否在2个矩形内同时是非透明点,那么就碰撞了
    for x in range(rect.width):
        for y in range(rect.height):
            if hitmask1[x1+x][y1+y] and hitmask2[x2+x][y2+y]:
                return True
    return False

# 展示得分,传入一个整数得分
def showScore(score):
    # 转成单个数字的列表
    scoreDigits = [int(x) for x in list(str(score))]
    #将得分 score 转换成字符串,然后将其每个字符(即每个单独的数字)转换成整数,并存储在列表 scoreDigits 中。这样,得分就被分解成了单个数字的列表

    # 计算展示所有数字要占多少像素宽度
    totalWidth = 0
    for digit in scoreDigits:
        totalWidth += IMAGES['numbers'][digit].get_width()
        '''
        遍历 scoreDigits 列表中的每个数字
        将每个数字图像的宽度累加到 totalWidth
        '''

    # 计算绘制起始x坐标
    Xoffset = (SCREENWIDTH - totalWidth) / 2

    # 逐个数字绘制
    for digit in scoreDigits:
        SCREEN.blit(IMAGES['numbers'][digit], (Xoffset, 20))    # y坐标贴近屏幕上边缘
        Xoffset += IMAGES['numbers'][digit].get_width() # 移动绘制x坐标
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291

三、q_game.py

"""
强化学习q learning flappy bird
"""
from game.wrapped_flappy_bird import GameState
import time
import numpy as np 
import skimage.color
import skimage.transform
import skimage.exposure
import tensorflow as tf 
import random 
import argparse

# 命令行参数
parser = argparse.ArgumentParser()#创建一个 ArgumentParser 对象,用于定义命令行参数
parser.add_argument("--model-only", help="加载已有模型,不随机探索,仍旧训练", action='store_true')
args = parser.parse_args()#解析命令行输入的参数,并将它们存储在 args 变量中

# 测试用代码
def _test_save_img(img):
    # 把每一帧图片存储到文件里,调试用
    from PIL import Image
    im = Image.fromarray((img*255).astype(np.uint8), mode='L') # 图片已经被处理为0~1之间的亮度值,所以*255取整数变灰度展示
    im.save('./img.jpg')

# 构建卷积神经网络
def build_model():
    # 卷积神经网络:https://blog.csdn.net/FontThrone/article/details/76652753
    model = tf.keras.models.Sequential([#创建一个 Sequential 模型,它是 tf.keras 中用于线性堆叠网络层的模型类
        tf.keras.layers.Input(shape=(80,80,4)),
        tf.keras.layers.Conv2D(filters=32, kernel_size=(8, 8), padding='same',strides=4, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(4, 4), padding='same',strides=2, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same',strides=1, activation='relu'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        tf.keras.layers.Flatten(),#将三维的卷积层输出展平为一维,以便传入到全连接层
        tf.keras.layers.Dense(256, activation='relu'),#定义一个具有 256 个单元的全连接层,并使用 ReLU 激活函数
        tf.keras.layers.Dense(2), # 对应2个action未来总回报预期
    ])
    model.compile(loss='mse', optimizer='adam')#编译模型,指定均方误差(MSE)作为损失函数,使用 Adam 优化器

    # 尝试加载之前保存的模型参数
    try:
        model.load_weights('./weights.h5')
        print('加载模型成功...................')
    except:
        pass
    return model

# 创建游戏
game = GameState()
# 卷积模型
model = build_model()

# 执行1帧游戏
def run_one_frame(action):
    global game 
    # image_data:执行动作后的图像(288*512*3的RGB三维数组)
    # reward:本次动作的奖励
    # terminal:游戏是否失败
    img, reward, terminal = game.frame_step(action)
    # RGB转灰度图
    img = skimage.color.rgb2gray(img)
    # 压缩到80*80的图片(根据RGB算出来的亮度,其数值很小)
    img = skimage.transform.resize(img, (80,80))
    # 把亮度标准化到0~1之间,用作模型输入
    img = skimage.exposure.rescale_intensity(img, out_range=(0,1))
    return img,reward,terminal

# 强化学习初始化状态
def reset_stat():
    # 执行第一帧,不点击
    img_t,_,_ =  run_one_frame([1,0])
    '''
    使用 numpy.stack 函数将首帧图像 img_t 重复四次,沿着第三个维度堆叠,
    形成初始状态 stat_t。这是因为卷积神经网络需要连续几帧的图像作为输入
    '''
    stat_t = np.stack([img_t] * 4, axis=2)
    return stat_t 

# 初始状态
stat_t = reset_stat()
# 训练样本
transitions = []#用于存储训练过程中的状态转换样本

# 时刻
t = 0

# 随机探索的概率控制,定义了随机探索概率的初始值、最终值和每次更新的步长。
INIT_EPSILON = 0.1
FINAL_EPSILON = 0.005
EPSLION_DELTA = 1e-6
# 最大留存样本个数
TRANS_CAP =  20000
# 至少有多少样本才训练
TRANS_SIZE_FIT = 10000
# 训练集大小
BATCH_SIZE = 32
# 未来激励折扣
GAMMA = 0.99

# 随机探索概率
if args.model_only: # 不随机探索(极低概率)
    epsilon = FINAL_EPSILON
else:
    epsilon = INIT_EPSILON

# 打印一些进度信息
rand_flap =0    # 随机点击次数
rand_noflap = 0 # 随机不点击次数
model_flap=0    # 模型点击次数
model_noflap=0  # 模型不点击次数
model_train_times = 0   # 模型训练次数

# 游戏启动
while True:    
    # 动作
    action_t = [0,0]

    action_type = '随机'#设置动作类型默认为 '随机',这将在选择动作时用于判断动作是随机选择的还是基于模型经验选择的。

    # 随着学习,降低随机探索的概率,让模型趋于稳定
    if (t <= TRANS_SIZE_FIT and not args.model_only) or random.random() <= epsilon:
        '''判断是否应该进行随机探索。如果在观察期内(t <= TRANS_SIZE_FIT)或者随机数小于或等于 epsilon,则执行随机探索'''
        n = random.random()
        if n <= 0.95:
            action_index = 0
            rand_noflap+=1
        else:
            action_index = 1
            rand_flap+=1
        #print('[随机探索] t时刻进行随机动作探索...')
    else: # 模型预测2个操作的未来累计回报
        action_type = '经验'
        Q_t = model.predict(np.expand_dims(stat_t, axis=0))[0]
        #使用当前的模型和状态 stat_t 来预测两个动作的未来总回报
        action_index = np.argmax(Q_t)   # 回报最大的action下标
        if action_index==0:
            model_noflap+=1
        else:
            model_flap+=1
        #print('[已有经验] 预测t时刻2个动作的未来总回报 -- 不点击:{} 点击:{}'.format(Q_t[0], Q_t[1]))

    action_t[action_index] = 1
    #print('时刻t将执行的动作为{}'.format(action_t))

    # 执行当前动作,返回操作后的图片、本次激励、游戏是否结束
    img_t1, reward, terminal = run_one_frame(action_t)
    _test_save_img(img_t1)
    img_t1 = img_t1.reshape((80,80,1)) # 增加通道维度,因为我们要最近4帧作为4通道图片,用作卷积模型输入
    stat_t1 = np.append(stat_t[:,:,1:], img_t1, axis=2) # 80*80*4,淘汰当前的第0通道,添加最新t1时刻到第3通道

    # 收集训练样本(保留有限的)
    transitions.append({
        'stat_t': stat_t,   # t时刻状态
        'stat_t1': stat_t1, # t1时刻状态
        'reward': reward,   # 本次动作的激励得分
        'terminal': terminal,   # 执行动作后游戏是否结束(ps: 结束意味着没有未来激励了)
        'action_index': action_index,   # 执行了什么动作(0:不点击,1:点击)
    })
    if len(transitions) > TRANS_CAP:
        transitions.pop(0)
    
    # 游戏结束则重置stat_t
    if terminal:
        stat_t = reset_stat()
        #print('死了!!!!!!! 状态t重置为初始帧...')
    else:   # 否则切为新的状态
        stat_t = stat_t1
        #print('没死~~~ 状态t切换为状态t1...')

    # 过了观察期,开始训练
    if t >= TRANS_SIZE_FIT and t % 10 == 0:
        minibatch = random.sample(transitions, BATCH_SIZE)
        # 模型训练的输入:t时刻的状态(最近4帧图片)
        inputs_t = np.concatenate([tran['stat_t'].reshape((1,80,80,4)) for tran in minibatch])
        #print('inputs_t shape', inputs_t.shape)
        ######################################################
        # 模型训练的输出:t时刻的未来总激励(Q_t = reward+gamma*Q_t1)
        # 1,让模型预测t时刻2种action的未来总激励
        Q_t = model.predict(inputs_t, batch_size=len(minibatch))
        # 2,让模型预测t1时刻2种action的未来总激励
        input_t1 = np.concatenate([tran['stat_t1'].reshape((1,80,80,4)) for tran in minibatch])
        Q_t1 = model.predict(input_t1, batch_size=len(minibatch))
        # 3,保留t1时刻2个action中最大的未来总激励
        Q_t1_max = [max(q) for q in Q_t1]
        # 4,t时刻进行action_index动作得到真实激励
        reward_t = [tran['reward'] for tran in minibatch]
        # 5,t时刻进行了什么action
        action_index_t = [tran['action_index'] for tran in minibatch]
        # 6,t1时刻是否死掉了
        terminal = [tran['terminal'] for tran in minibatch]
        # 7,修正训练的目标Q_t=reward+gamma*Q_t1
        # (t时刻action_index的未来总激励=action_index真实激励+t1时刻预测的最大未来总激励)
        for i in range(len(minibatch)):
            if terminal[i]:
                Q_t[i][action_index_t[i]] = reward_t[i] # 因为t1时刻已经死了,所以没有t1之后的累计激励
            else:
                Q_t[i][action_index_t[i]] = reward_t[i] + GAMMA*Q_t1_max[i] # Q_t=reward+Q_t1
        # print('Q_t shape', Q_t.shape)
        # 训练一波
        #print(inputs_t)
        #print(Q_t)
        model.fit(inputs_t, Q_t, batch_size=len(minibatch))
        model_train_times += 1
        # 训练1次则降低些许的随机探索概率
        if epsilon > FINAL_EPSILON:
            epsilon -= EPSLION_DELTA
        
        # 每5000次batch保存一次模型权重(不适用saved_model,后续加载只会加载权重,模型结构还是程序构造,因为这样可以保持keras model的api)
        if model_train_times % 5000 == 0:
            model.save_weights('./weights.h5')

        ######################################################
    if t % 100 == 0:
        print('总帧数:{} 剩余探索概率:{}% 累计训练次数:{} 累计随机点:{} 累计随机不点:{} 累计模型点:{} 累计模型不点:{} 训练集:{} '.format(
            t, round(epsilon * 100, 4), model_train_times, rand_flap, rand_noflap, model_flap, model_noflap,
            len(transitions)))
    t = t + 1
    #time.sleep(1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221

四、text_game.py

"""
演示pygame制作的flappy bird如何逐帧调用执行
"""
from game.wrapped_flappy_bird import GameState
from random import random
import time

# 创建游戏
game = GameState()

# 游戏启动
while True:
    r = random()
    if r <= 0.92:  # 92%的概率不点击屏幕
        game.frame_step([1,0]) # 动作:[1,0] 表示不点击
    else: # 8%的概率点击屏幕
        game.frame_step([0,1]) # 动作:[0,1] 表示点击
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

五、训练结果

请添加图片描述

代码源自:强化学习Deep Q-Network自动玩flappy bird | 鱼儿的博客 (yuerblog.cc)

仅想具体看一下工作原理和代码,仅供学习使用

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/760680
推荐阅读
相关标签
  

闽ICP备14008679号