
DQN_tensorflow Source Code Walkthrough

   I have recently been studying "Playing Atari with Deep Reinforcement Learning", DeepMind's original DQN paper, for a research project. Many open-source implementations of this paper exist; here I use the one at https://github.com/gliese581gg/DQN_tensorflow as an example to walk through the concrete training and learning process of deep reinforcement learning. The code is based on TensorFlow and OpenCV. I have annotated it in detail, and I hope the comments are helpful.

The main module defines the deep_atari class and provides the interface for training and evaluation; the params dictionary holds the network and training hyperparameters (see the code below for the concrete configuration).
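
The `__main__` block at the bottom of this file parses a handful of command-line switches (`-weight`, `-network_type`, `-visualize`, `-gpu_fraction`, `-db_size`, `-only_eval`) and overrides the matching entries in `params`. Assuming the file is saved as, say, `main.py` (the actual filename in the repository may differ), training the NIPS-style network could be launched with something like `python main.py -network_type nips -visualize n -gpu_fraction 0.25`.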

from database import *
from emulator import *
import tensorflow as tf
import numpy as np
import time
from ale_python_interface import ALEInterface
import cv2
from scipy import misc
import gc #garbage collector
import thread
import sys

gc.enable()

#hyper-parameters for the network and training
params = {
    'visualize' : True,
    'network_type':'nips',
    'ckpt_file':None,
    'steps_per_epoch': 50000,
    'num_epochs': 100,
    'eval_freq':50000,
    'steps_per_eval':10000,
    'copy_freq' : 10000,
    'disp_freq':10000,
    'save_interval':10000,
    'db_size': 1000000,
    'batch': 32,
    'num_act': 0,
    'input_dims' : [210, 160, 3],
    'input_dims_proc' : [84, 84, 4],
    'learning_interval': 1,
    'eps': 1.0,
    'eps_step':1000000,
    'eps_min' : 0.1,
    'eps_eval' : 0.05,
    'discount': 0.95,
    'lr': 0.0002,
    'rms_decay':0.99,
    'rms_eps':1e-6,
    'train_start':100,
    'img_scale':255.0,
    'clip_delta' : 0, #nature : 1
    'gpu_fraction' : 0.25,
    'batch_accumulator':'mean',
    'record_eval' : True,
    'only_eval' : 'n'
}

class deep_atari:
    def __init__(self,params):
        print 'Initializing Module...'
        self.params = params

        self.gpu_config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.params['gpu_fraction']))

        self.sess = tf.Session(config=self.gpu_config)
        self.DB = database(self.params)  # initialize the replay memory
        self.engine = emulator(rom_name='breakout.bin', vis=self.params['visualize'],windowname=self.params['network_type']+'_preview')
        self.params['num_act'] = len(self.engine.legal_actions)  # number of legal actions in this game
        self.build_net()  # build the Q-network (qnet) and the target network (targetnet)
        self.training = True

    def build_net(self):
        print 'Building QNet and targetnet...'  

        '''qnet is the network being trained and targetnet provides the training targets.
        The whole model can be thought of as a beginner learning how to play the game from its
        accumulated experience, and that experience lives in DB (the replay memory).'''
        self.qnet = DQN(self.params,'qnet')  # the Q-network being trained
        self.targetnet = DQN(self.params,'targetnet')  # the target network
        self.sess.run(tf.initialize_all_variables())
        saver_dict = {'qw1':self.qnet.w1,'qb1':self.qnet.b1,
                'qw2':self.qnet.w2,'qb2':self.qnet.b2,
                'qw3':self.qnet.w3,'qb3':self.qnet.b3,
                'qw4':self.qnet.w4,'qb4':self.qnet.b4,
                'qw5':self.qnet.w5,'qb5':self.qnet.b5,
                'tw1':self.targetnet.w1,'tb1':self.targetnet.b1,
                'tw2':self.targetnet.w2,'tb2':self.targetnet.b2,
                'tw3':self.targetnet.w3,'tb3':self.targetnet.b3,
                'tw4':self.targetnet.w4,'tb4':self.targetnet.b4,
                'tw5':self.targetnet.w5,'tb5':self.targetnet.b5,
                'step':self.qnet.global_step}  # weights, biases and the global step to store in the checkpoint
        self.saver = tf.train.Saver(saver_dict)
        #self.saver = tf.train.Saver()
        #ops that copy qnet's weights and biases into targetnet
        self.cp_ops = [
            self.targetnet.w1.assign(self.qnet.w1),self.targetnet.b1.assign(self.qnet.b1),
            self.targetnet.w2.assign(self.qnet.w2),self.targetnet.b2.assign(self.qnet.b2),
            self.targetnet.w3.assign(self.qnet.w3),self.targetnet.b3.assign(self.qnet.b3),
            self.targetnet.w4.assign(self.qnet.w4),self.targetnet.b4.assign(self.qnet.b4),
            self.targetnet.w5.assign(self.qnet.w5),self.targetnet.b5.assign(self.qnet.b5)]

        self.sess.run(self.cp_ops)

        if self.params['ckpt_file'] is not None:  # restore the state of a previous training run
            print 'loading checkpoint : ' + self.params['ckpt_file']
            self.saver.restore(self.sess,self.params['ckpt_file'])
            temp_train_cnt = self.sess.run(self.qnet.global_step)
            temp_step = temp_train_cnt * self.params['learning_interval']
            print 'Continue from'
            print '        -> Steps : ' + str(temp_step)
            print '        -> Minibatch update : ' + str(temp_train_cnt)


    def start(self):  # main training / evaluation loop
        self.reset_game()  # start a new game
        self.step = 0  # current step count
        self.reset_statistics('all')  # reset all statistics
        self.train_cnt = self.sess.run(self.qnet.global_step)

        #if resuming from a previous run, append to the existing log files (or create them)
        if self.train_cnt > 0 :
            self.step = self.train_cnt * self.params['learning_interval']
            try:
                self.log_train = open('log_training_'+self.params['network_type']+'.csv','a')
            except:
                self.log_train = open('log_training_'+self.params['network_type']+'.csv','w')
                self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')    

            try:
                self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','a')
            except:
                self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w')
                self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')
        else:
                self.log_train = open('log_training_'+self.params['network_type']+'.csv','w')
                self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')    
                self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w')
                self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')

        self.s = time.time()
        #print the configuration
        print self.params
        print 'Start training!'
        print 'Collecting replay memory for ' + str(self.params['train_start']) + ' steps'

        #main loop; for the first params['train_start'] steps the game is only played to fill the initial replay memory
        while self.step < (self.params['steps_per_epoch'] * self.params['num_epochs'] * self.params['learning_interval'] + self.params['train_start']): 
            if self.training : 
                if self.DB.get_size() >= self.params['train_start'] : self.step += 1 ; self.steps_train += 1
            else : self.step_eval += 1

            #store the previous state, the clipped reward, the index of the action taken and the terminal flag in DB (the replay memory)
            if self.state_gray_old is not None and self.training:
                self.DB.insert(self.state_gray_old[26:110,:],self.reward_scaled,self.action_idx,self.terminal)

            #every params['copy_freq'] steps, copy qnet's parameters into targetnet
            if self.training and self.params['copy_freq'] > 0 and self.step % self.params['copy_freq'] == 0 and self.DB.get_size() > self.params['train_start']:
                print '&&& Copying Qnet to targetnet\n'
                self.sess.run(self.cp_ops)

            #update the weights every params['learning_interval'] steps; learning_interval = 1 means one training step after every action
            if self.training and self.step % self.params['learning_interval'] == 0 and self.DB.get_size() > self.params['train_start'] :

                '''sample a random minibatch of transitions from DB (the replay memory) for training:
                the state s, the index of the action a, the terminal flag, the next state reached after taking a, and the reward'''
                bat_s,bat_a,bat_t,bat_n,bat_r = self.DB.get_batches()
                bat_a = self.get_onehot(bat_a)  # turn the action indices into a one-hot matrix of shape [batch, num_act]

                #feed the sampled next states through targetnet; the max over its outputs (q_t) estimates the maximum future reward from the next state
                if self.params['copy_freq'] > 0 :
                    feed_dict={self.targetnet.x: bat_n}
                    q_t = self.sess.run(self.targetnet.y,feed_dict=feed_dict)
                else:
                    feed_dict={self.qnet.x: bat_n}
                    q_t = self.sess.run(self.qnet.y,feed_dict=feed_dict)

                q_t = np.amax(q_t,axis=1)

                #feed the sampled transitions (the accumulated experience) into qnet
                feed_dict={self.qnet.x: bat_s, self.qnet.q_t: q_t, self.qnet.actions: bat_a, self.qnet.terminals:bat_t, self.qnet.rewards: bat_r}

                #run one RMSProp update on qnet and read back the loss
                _,self.train_cnt,self.cost = self.sess.run([self.qnet.rmsprop,self.qnet.global_step,self.qnet.cost],feed_dict=feed_dict)

                #accumulate the loss for logging
                self.total_cost_train += np.sqrt(self.cost)
                self.train_cnt_for_disp += 1

            if self.training :              
                self.params['eps'] = max(self.params['eps_min'],1.0 - float(self.train_cnt * self.params['learning_interval'])/float(self.params['eps_step']))
            else:
                self.params['eps'] = 0.05

            #save the weights every params['save_interval'] steps (like a snapshot in Caffe); the rest of the loop is bookkeeping: logging, evaluation, game resets and so on
            if self.DB.get_size() > self.params['train_start'] and self.step % self.params['save_interval'] == 0 and self.training:
                save_idx = self.train_cnt
                self.saver.save(self.sess,'ckpt/model_'+self.params['network_type']+'_'+str(save_idx))
                sys.stdout.write('$$$ Model saved : %s\n\n' % ('ckpt/model_'+self.params['network_type']+'_'+str(save_idx)))
                sys.stdout.flush()
            #periodic training log
            if self.training and self.step > 0 and self.step % self.params['disp_freq']  == 0 and self.DB.get_size() > self.params['train_start'] : 
                self.write_log_train()

            #switch to evaluation: play the game with the small fixed epsilon while training updates are suspended
            if self.training and self.step > 0 and self.step % self.params['eval_freq'] == 0 and self.DB.get_size() > self.params['train_start'] : 

                self.reset_game()
                if self.step % self.params['steps_per_epoch'] == 0 : self.reset_statistics('all')
                else: self.reset_statistics('eval')
                self.training = False
                #TODO : add video recording             
                continue

            #during training, restart the game every params['steps_per_epoch'] steps
            #(because of the discount factor gamma, rewards far in the future contribute little to the current return, so very long episodes add little to the weight updates; see the paper)
            if self.training and self.step > 0 and self.step % self.params['steps_per_epoch'] == 0 and self.DB.get_size() > self.params['train_start']: 
                self.reset_game()
                self.reset_statistics('all')
                #self.training = False
                continue

            if not self.training and self.step_eval >= self.params['steps_per_eval'] :
                self.write_log_eval()
                self.reset_game()
                self.reset_statistics('eval')
                self.training = True
                continue

            #handle game over
            if self.terminal :  
                self.reset_game()
                if self.training : 
                    self.num_epi_train += 1 
                    self.total_reward_train += self.epi_reward_train
                    self.epi_reward_train = 0
                else : 
                    self.num_epi_eval += 1 
                    self.total_reward_eval += self.epi_reward_eval
                    self.epi_reward_eval = 0
                continue
            '''choose the next action with select_action(), i.e. epsilon-greedy with a linearly annealed epsilon'''
            self.action_idx,self.action, self.maxQ = self.select_action(self.state_proc)

            #execute the chosen action in the emulator; it returns the new frame, the reward and whether the game terminated (a new transition of the Markov chain)
            self.state, self.reward, self.terminal = self.engine.next(self.action)
            self.reward_scaled = self.reward // max(1,abs(self.reward))  # clip the reward to its sign (-1, 0 or 1)
            if self.training : self.epi_reward_train += self.reward ; self.total_Q_train += self.maxQ  # accumulate episode reward and Q statistics
            else : self.epi_reward_eval += self.reward ; self.total_Q_eval += self.maxQ 

            #preprocess the new frame (resize, grayscale, crop, scale) and append it to the 4-frame stack so it can later be stored in DB (the replay memory)
            self.state_gray_old = np.copy(self.state_gray)
            self.state_proc[:,:,0:3] = self.state_proc[:,:,1:4]
            self.state_resized = cv2.resize(self.state,(84,110))
            self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY)
            self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale']

            #TODO : add video recording

    def reset_game(self):
        self.state_proc = np.zeros((84,84,4)); self.action = -1; self.terminal = False; self.reward = 0
        self.state = self.engine.newGame()      
        self.state_resized = cv2.resize(self.state,(84,110))
        self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY)
        self.state_gray_old = None
        self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale']

    def reset_statistics(self,mode):
        if mode == 'all':
            self.epi_reward_train = 0
            self.epi_Q_train = 0
            self.num_epi_train = 0
            self.total_reward_train = 0
            self.total_Q_train = 0
            self.total_cost_train = 0
            self.steps_train = 0
            self.train_cnt_for_disp = 0
        self.step_eval = 0
        self.epi_reward_eval = 0
        self.epi_Q_eval = 0     
        self.num_epi_eval = 0       
        self.total_reward_eval = 0
        self.total_Q_eval = 0


    def write_log_train(self):
        sys.stdout.write('### Training (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] ))

        sys.stdout.write('    Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f, Avg.loss : %.3f\n' % (self.num_epi_train,float(self.total_reward_train)/max(1,self.num_epi_train),float(self.total_Q_train)/max(1,self.steps_train),self.total_cost_train/max(1,self.train_cnt_for_disp)))
        sys.stdout.write('    Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s))
        sys.stdout.flush()
        self.log_train.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',')
        self.log_train.write(str(float(self.total_reward_train)/max(1,self.num_epi_train)) +','+ str(float(self.total_Q_train)/max(1,self.steps_train)) +',')
        self.log_train.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n')
        self.log_train.flush()      

    def write_log_eval(self):
        sys.stdout.write('@@@ Evaluation (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] ))
        sys.stdout.write('    Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f\n' % (self.num_epi_eval,float(self.total_reward_eval)/max(1,self.num_epi_eval),float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])))
        sys.stdout.write('    Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s))
        sys.stdout.flush()
        self.log_eval.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',')
        self.log_eval.write(str(float(self.total_reward_eval)/max(1,self.num_epi_eval)) +','+ str(float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])) +',')
        self.log_eval.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n')
        self.log_eval.flush()

    def select_action(self,st):
        if np.random.rand() > self.params['eps']:  # with probability 1 - eps, let qnet (not targetnet) pick the greedy action
            #greedy with random tie-breaking
            Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0] 
            a_winner = np.argwhere(Q_pred == np.amax(Q_pred))
            if len(a_winner) > 1:
                act_idx = a_winner[np.random.randint(0, len(a_winner))][0]
                return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred)
            else:
                act_idx = a_winner[0][0]
                return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred)
        #otherwise pick a random action
        else:
            #random
            act_idx = np.random.randint(0,len(self.engine.legal_actions))
            Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0]
            return act_idx,self.engine.legal_actions[act_idx], Q_pred[act_idx]

    def get_onehot(self,actions):
        actions_onehot = np.zeros((self.params['batch'], self.params['num_act']))

        for i in range(self.params['batch']):
            actions_onehot[i,actions[i]] = 1
        return actions_onehot


if __name__ == "__main__":
    dict_items = params.items()
    for i in range(1,len(sys.argv),2):
        if sys.argv[i] == '-weight' :params['ckpt_file'] = sys.argv[i+1]
        elif sys.argv[i] == '-network_type' :params['network_type'] = sys.argv[i+1]
        elif sys.argv[i] == '-visualize' :
            if sys.argv[i+1] == 'y' : params['visualize'] = True
            elif sys.argv[i+1] == 'n' : params['visualize'] = False
            else:
                print 'Invalid visualization argument!!! Available arguments are'
                print '        y or n'
                raise ValueError()
        elif sys.argv[i] == '-gpu_fraction' : params['gpu_fraction'] = float(sys.argv[i+1])
        elif sys.argv[i] == '-db_size' : params['db_size'] = int(sys.argv[i+1])
        elif sys.argv[i] == '-only_eval' : params['only_eval'] = sys.argv[i+1]
        else : 
            print 'Invalid arguments!!! Available arguments are'
            print '        -weight (filename)'
            print '        -network_type (nips or nature)'
            print '        -visualize (y or n)'
            print '        -gpu_fraction (0.1~0.9)'
            print '        -db_size (integer)'
            raise ValueError()
    if params['network_type'] == 'nips':
        from DQN_nips import *
    elif params['network_type'] == 'nature':
        from DQN_nature import *
        params['steps_per_epoch']= 200000
        params['eval_freq'] = 100000
        params['steps_per_eval'] = 10000
        params['copy_freq'] = 10000
        params['disp_freq'] = 20000
        params['save_interval'] = 20000
        params['learning_interval'] = 1
        params['discount'] = 0.99
        params['lr'] = 0.00025
        params['rms_decay'] = 0.95
        params['rms_eps']=0.01
        params['clip_delta'] = 1.0
        params['train_start']=50000
        params['batch_accumulator'] = 'sum'
        params['eps_step'] = 1000000
        params['num_epochs'] = 250
        params['batch'] = 32
    else :
        print 'Invalid network type! Available network types are'
        print '        nips or nature'
        raise ValueError()

    if params['only_eval'] == 'y' : only_eval = True
    elif params['only_eval'] == 'n' : only_eval = False
    else :
        print 'Invalid only_eval option! Available options are'
        print '        y or n'
        raise ValueError()

    if only_eval:
        params['eval_freq'] = 1
        params['train_start'] = 100

    da = deep_atari(params)
    da.start()
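
To make the state handling at the end of the training loop easier to follow, here is a small standalone sketch of the same preprocessing (an illustration only, not part of the original repository): each raw RGB frame is resized to 84x110, converted to grayscale, cropped to the 84x84 play area (rows 26 to 109), scaled by `img_scale`, and pushed into the 4-frame stack that forms the network input.

import numpy as np
import cv2

def preprocess(frame_rgb, img_scale=255.0):
    # frame_rgb: raw emulator frame of shape (210, 160, 3), dtype uint8
    resized = cv2.resize(frame_rgb, (84, 110))        # width 84, height 110
    gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)  # shape (110, 84)
    cropped = gray[26:110, :]                         # keep the 84x84 play area
    return cropped / img_scale                        # scale pixels into [0, 1]

# the 4-frame state stack, updated exactly like self.state_proc in the loop above
state_proc = np.zeros((84, 84, 4))
frame = np.zeros((210, 160, 3), dtype=np.uint8)       # dummy frame, for illustration
state_proc[:, :, 0:3] = state_proc[:, :, 1:4]         # shift the three newest frames back
state_proc[:, :, 3] = preprocess(frame)               # append the latest frame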

The database class implements the replay memory described in the paper.

import numpy as np
import gc
import time
import cv2

class database:
    def __init__(self, params):
        self.size = params['db_size']
        self.img_scale = params['img_scale']
        self.states = np.zeros([self.size,84,84],dtype='uint8') #image dimensions
        self.actions = np.zeros(self.size,dtype='float32')
        self.terminals = np.zeros(self.size,dtype='float32')
        self.rewards = np.zeros(self.size,dtype='float32')
        self.bat_size = params['batch']
        self.bat_s = np.zeros([self.bat_size,84,84,4])
        self.bat_a = np.zeros([self.bat_size])
        self.bat_t = np.zeros([self.bat_size])
        self.bat_n = np.zeros([self.bat_size,84,84,4])
        self.bat_r = np.zeros([self.bat_size])

        self.counter = 0 #keep track of next empty state
        self.flag = False
        return

    def get_batches(self):#get random replay memory     
        for i in range(self.bat_size):  # draw batch-size (32) transitions from the replay memory
            idx = 0
            while idx < 3 or (idx > self.counter-2 and idx < self.counter+3):
                idx = np.random.randint(3,self.get_size()-1)  # get_size() is the number of stored states; keep sampling until the index is valid
            #gather the values for this idx: 4-frame state, 4-frame next state, action, terminal and reward
            self.bat_s[i] = np.transpose(self.states[idx-3:idx+1,:,:],(1,2,0))/self.img_scale
            self.bat_n[i] = np.transpose(self.states[idx-2:idx+2,:,:],(1,2,0))/self.img_scale
            self.bat_a[i] = self.actions[idx]
            self.bat_t[i] = self.terminals[idx]
            self.bat_r[i] = self.rewards[idx]
        #self.bat_s[0] = np.transpose(self.states[10:14,:,:],(1,2,0))/self.img_scale
        #self.bat_n[0] = np.transpose(self.states[11:15,:,:],(1,2,0))/self.img_scale
        #self.bat_a[0] = self.actions[13]
        #self.bat_t[0] = self.terminals[13]
        #self.bat_r[0] = self.rewards[13]

        return self.bat_s,self.bat_a,self.bat_t,self.bat_n,self.bat_r

    def insert(self, prevstate_proc,reward,action,terminal):  # add one transition to the replay memory
        self.states[self.counter] = prevstate_proc
        self.rewards[self.counter] = reward
        self.actions[self.counter] = action
        self.terminals[self.counter] = terminal
        #update counter
        self.counter += 1
        if self.counter >= self.size:
            self.flag = True
            self.counter = 0
        return

    def get_size(self):  # current number of transitions stored in the replay memory
        if self.flag == False:
            return self.counter
        else:
            return self.size
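
A quick sketch (illustration only, not from the repository) of what a single sampled index produces in get_batches: the four frames ending at idx form the current state stack, and the four frames ending at idx+1 form the next-state stack, so the two stacks overlap by three frames.

import numpy as np

# 10 dummy 84x84 frames; frame t is filled with the value t
states = np.arange(10)[:, None, None] * np.ones((10, 84, 84), dtype='uint8')
idx = 5
bat_s = np.transpose(states[idx-3:idx+1, :, :], (1, 2, 0))  # frames 2,3,4,5 -> current state
bat_n = np.transpose(states[idx-2:idx+2, :, :], (1, 2, 0))  # frames 3,4,5,6 -> next state
print(bat_s[0, 0, :])  # [2 3 4 5]
print(bat_n[0, 0, :])  # [3 4 5 6]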

The DQN class is the core of the code: it defines the network architecture, the Bellman target, and the loss function.
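
In the code's notation, the Bellman target computed below is yj = rewards + (1 - terminals) * discount * q_t, where q_t = max_a' Q_target(next_state, a') comes from the target network, and the loss is the (optionally clipped) squared difference between yj and the Q-value that qnet assigns to the action actually taken.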

import numpy as np
import tensorflow as tf
import cv2

class DQN:
    def __init__(self,params,name):
        # tensorflow placeholders for the elements of a transition; the input x has shape [batch,84,84,4] (params['batch'] = 32)
        self.network_type = 'nature'
        self.params = params
        self.network_name = name
        self.x = tf.placeholder('float32',[None,84,84,4],name=self.network_name + '_x')
        self.q_t = tf.placeholder('float32',[None],name=self.network_name + '_q_t')
        self.actions = tf.placeholder("float32", [None, params['num_act']],name=self.network_name + '_actions')
        self.rewards = tf.placeholder("float32", [None],name=self.network_name + '_rewards')
        self.terminals = tf.placeholder("float32", [None],name=self.network_name + '_terminals')

        #conv1: [batch,84,84,4] --> [batch,w1,h1,32] (output size follows the usual formula: out = (in + 2*padding - kernel)/stride + 1)
        layer_name = 'conv1' ; size = 8 ; channels = 4 ; filters = 32 ; stride = 4
        self.w1 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b1 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c1 = tf.nn.conv2d(self.x, self.w1, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o1 = tf.nn.relu(tf.add(self.c1,self.b1),name=self.network_name + '_'+layer_name+'_activations')
        #self.n1 = tf.nn.lrn(self.o1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        #conv2,[32,w1,h1,32]-->[32,w2,h2,64]
        layer_name = 'conv2' ; size = 4 ; channels = 32 ; filters = 64 ; stride = 2
        self.w2 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b2 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c2 = tf.nn.conv2d(self.o1, self.w2, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o2 = tf.nn.relu(tf.add(self.c2,self.b2),name=self.network_name + '_'+layer_name+'_activations')
        #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        #conv3,[32,w2,h2,64]-->[32,w3,h3,64]
        layer_name = 'conv3' ; size = 3 ; channels = 64 ; filters = 64 ; stride = 1
        self.w3 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b3 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c3 = tf.nn.conv2d(self.o2, self.w3, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o3 = tf.nn.relu(tf.add(self.c3,self.b3),name=self.network_name + '_'+layer_name+'_activations')
        #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        #flatten the conv output into a vector
        o3_shape = self.o3.get_shape().as_list()        

        #fc4: [batch,w3*h3*64] --> [batch,512]
        layer_name = 'fc4' ; hiddens = 512 ; dim = o3_shape[1]*o3_shape[2]*o3_shape[3]
        self.o3_flat = tf.reshape(self.o3, [-1,dim],name=self.network_name + '_'+layer_name+'_input_flat')
        self.w4 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b4 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases')
        self.ip4 = tf.add(tf.matmul(self.o3_flat,self.w4),self.b4,name=self.network_name + '_'+layer_name+'_ips')
        self.o4 = tf.nn.relu(self.ip4,name=self.network_name + '_'+layer_name+'_activations')

        #fc5: [batch,512] --> [batch,num_act]
        layer_name = 'fc5' ; hiddens = params['num_act'] ; dim = 512
        self.w5 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b5 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases')
        '''One of the core ideas of deep Q-learning: act so as to maximise the expected future reward. Since that
        future reward cannot be summed backwards from the future, the Bellman equation is used instead: assume the
        maximum future reward of the *next* state is already known and estimate it with the target network; this
        value serves as the label when training qnet. The idea is analogous to recursion in algorithms or induction
        in mathematics. The network itself takes the current state (a stack of 4 consecutive frames) as input and
        outputs, for every action, an estimate of the maximum future reward; the largest of these outputs is the
        prediction for the current state. Interestingly, both the label and the prediction come from a neural
        network (two different ones). From there, training proceeds like any other convolutional network: choose a
        loss, back-propagate the error, and update the weights.'''
        self.y = tf.add(tf.matmul(self.o4,self.w5),self.b5,name=self.network_name + '_'+layer_name+'_outputs')

        #Q,Cost,Optimizer
        self.discount = tf.constant(self.params['discount'])  # the discount factor gamma in the Bellman equation

        #Bellman target: the discounted maximum future reward (the future term is zeroed at terminal states)
        self.yj = tf.add(self.rewards, tf.mul(1.0-self.terminals, tf.mul(self.discount, self.q_t)))
        self.Qxa = tf.mul(self.y,self.actions)
        self.Q_pred = tf.reduce_max(self.Qxa, reduction_indices=1)
        #self.yjr = tf.reshape(self.yj,(-1,1))
        #self.yjtile = tf.concat(1,[self.yjr,self.yjr,self.yjr,self.yjr])
        #self.yjax = tf.mul(self.yjtile,self.actions)

        #half = tf.constant(0.5)
        self.diff = tf.sub(self.yj, self.Q_pred)
        if self.params['clip_delta'] > 0 :
            self.quadratic_part = tf.minimum(tf.abs(self.diff), tf.constant(self.params['clip_delta']))  # quadratic part of the Huber-style loss, capped at clip_delta
            self.linear_part = tf.sub(tf.abs(self.diff),self.quadratic_part)
            self.diff_square = 0.5 * tf.pow(self.quadratic_part,2) + self.params['clip_delta']*self.linear_part


        else:
            self.diff_square = tf.mul(tf.constant(0.5),tf.pow(self.diff, 2))

        if self.params['batch_accumulator'] == 'sum':
            self.cost = tf.reduce_sum(self.diff_square)
        else:
            self.cost = tf.reduce_mean(self.diff_square)

        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step)
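
As a rough check of the loss defined above, here is an illustrative NumPy sketch (not part of the repository) of the same computation for the default clip_delta = 0 case: the Bellman target yj, the predicted Q-value of the chosen action, and the squared-error cost with either 'mean' or 'sum' batch accumulation.

import numpy as np

discount  = 0.95
rewards   = np.array([1.0, 0.0])
terminals = np.array([0.0, 1.0])                 # the second transition ends an episode
q_t       = np.array([2.0, 3.0])                 # max_a' Q_target(s', a') from the target net
y         = np.array([[0.5, 1.5], [2.0, 0.2]])   # qnet outputs for the two sampled states
actions   = np.array([[0.0, 1.0], [1.0, 0.0]])   # one-hot encoding of the actions taken

yj     = rewards + (1.0 - terminals) * discount * q_t  # Bellman target; no future term at terminal states
Q_pred = np.max(y * actions, axis=1)                   # Q-value of the chosen action (mirrors reduce_max(Qxa))
diff   = yj - Q_pred
print(np.mean(0.5 * diff ** 2))                        # batch_accumulator == 'mean' (nips settings)
print(np.sum(0.5 * diff ** 2))                         # batch_accumulator == 'sum'  (nature settings)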

The emulator class wraps the Atari game interface (ALE): it returns the current screen image and reward, and provides functions to reset the game and start a new one.

import numpy as np
import copy
import sys
from ale_python_interface import ALEInterface
import cv2
import time
#import scipy.misc

class emulator:
    def __init__(self, rom_name, vis,windowname='preview'):
        self.ale = ALEInterface()
        self.max_frames_per_episode = self.ale.getInt("max_num_frames_per_episode");
        self.ale.setInt("random_seed",123)
        self.ale.setInt("frame_skip",4)
        self.ale.loadROM('roms/' + rom_name )
        self.legal_actions = self.ale.getMinimalActionSet()
        self.action_map = dict()
        self.windowname = windowname
        for i in range(len(self.legal_actions)):
            self.action_map[self.legal_actions[i]] = i

        # print(self.legal_actions)
        self.screen_width,self.screen_height = self.ale.getScreenDims()
        print("width/height: " +str(self.screen_width) + "/" + str(self.screen_height))
        self.vis = vis
        if vis: 
            cv2.startWindowThread()
            cv2.namedWindow(self.windowname)

    def get_image(self):  # grab the current screen image
        numpy_surface = np.zeros(self.screen_height*self.screen_width*3, dtype=np.uint8)
        self.ale.getScreenRGB(numpy_surface)
        image = np.reshape(numpy_surface, (self.screen_height, self.screen_width, 3))
        return image

    def newGame(self):
        self.ale.reset_game()  # reset and start a new game
        return self.get_image()  # return the initial frame

    def next(self, action_indx):
        reward = self.ale.act(action_indx)  
        nextstate = self.get_image()
        # scipy.misc.imsave('test.png',nextstate)
        if self.vis:
            cv2.imshow(self.windowname,nextstate)
        return nextstate, reward, self.ale.game_over()



if __name__ == "__main__":
    engine = emulator('breakout.bin',True)
    engine.next(0)
    time.sleep(5)
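
A minimal usage sketch for the class above (an illustration, assuming the ROM file is available under roms/): create the emulator, start a game, and step it with random legal actions.

import numpy as np

engine = emulator(rom_name='breakout.bin', vis=False)
state = engine.newGame()                              # initial RGB frame, shape (210, 160, 3)
for _ in range(10):
    action = engine.legal_actions[np.random.randint(len(engine.legal_actions))]
    state, reward, terminal = engine.next(action)     # returns (next frame, reward, game-over flag)
    if terminal:
        state = engine.newGame()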

Paper: https://arxiv.org/pdf/1312.5602.pdf
Code: https://github.com/gliese581gg/DQN_tensorflow

I will post my notes on the paper later.

If anything here is inaccurate or incomplete, corrections and suggestions are welcome.
