Notes on the DQN code I use myself; I understand it roughly by now, so I'm just jotting things down.
1. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
2. First class: building the neural network
class Net(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Net, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # three fully connected layers, weights initialized from N(0, 0.1)
        self.fc1 = nn.Linear(self.state_dim, 64)
        self.fc1.weight.data.normal_(0, 0.1)
        self.fc2 = nn.Linear(64, 128)
        self.fc2.weight.data.normal_(0, 0.1)
        self.fc3 = nn.Linear(128, self.action_dim)
        self.fc3.weight.data.normal_(0, 0.1)

    def forward(self, x):
        # reshape so x has shape (batch, state_dim); a singleton middle dimension is dropped
        x = x.view(x.size(0), x.size(-1))
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
For self and __init__, you can refer to:
解惑(一) ----- super(XXX, self).__init__()到底是代表什么含义_奋斗の博客-CSDN博客
My own notes:
2.1. class Net(nn.Module):
Inherits from nn.Module; the parent class is nn.Module.
2.2. __init__(self, state_dim, action_dim)
self is the instance itself; within this class, self can roughly be read as the Net instance being constructed.
state_dim and action_dim are the parameters the network needs at construction time: the state dimension and the action dimension (two numbers).
2.3. super(Net, self).__init__()
The Net class inherits from the parent class nn.Module.
super(Net, self).__init__() initializes the attributes inherited from nn.Module, so you do not have to write that initialization yourself (a small sketch follows below).
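A minimal sketch of why this call matters (the class name Tiny is just an illustration, not part of the original code): nn.Module.__init__ sets up the internal bookkeeping that registers submodules and parameters, so it has to run before any layer is assigned to self.

# sketch only: Tiny is a made-up example class
class Tiny(nn.Module):
    def __init__(self):
        super(Tiny, self).__init__()   # without this line, the next assignment fails with
        self.fc = nn.Linear(4, 2)      # an error like "cannot assign module before Module.__init__() call"

net = Tiny()
print(list(net.parameters()))          # the Linear's weight and bias were registered automatically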
2.4. The input x to the forward function is the state (a tensor of shape [1, state_dim]: the original state is a vector, an extra dimension is added and it is then converted to a tensor).
The network's final output x is a tensor of shape [1, action_dim] (carrying grad_fn), one Q-value per action.
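A quick sketch of the shapes described in 2.4, using made-up example sizes (state_dim=4, action_dim=2 are illustrative values, not from the original code):

state_dim, action_dim = 4, 2                     # example sizes only
net = Net(state_dim, action_dim)

state = np.array([0.1, -0.2, 0.3, 0.0])          # original state: a 1-D vector
x = torch.FloatTensor(np.expand_dims(state, 0))  # add a batch dimension -> shape (1, state_dim)
q = net(x)                                       # shape (1, action_dim), carries grad_fn
print(q.shape)                                   # torch.Size([1, 2])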
3. Second class: the replay buffer
class replay_buffer(object):
    def __init__(self, capacity):
        self.capacity = capacity
        # deque drops the oldest transition automatically once capacity is reached
        self.memory = deque(maxlen=self.capacity)  # key point

    def store(self, state, action, reward, next_state):
        state = np.expand_dims(state, 0)
        next_state = np.expand_dims(next_state, 0)
        self.memory.append([state, action, reward, next_state])

    def sample(self, size):
        batch = random.sample(self.memory, size)
        state, action, reward, next_state = zip(*batch)
        return np.concatenate(state, 0), action, reward, np.concatenate(next_state, 0)

    def __len__(self):
        return len(self.memory)
3.1. class replay_buffer(object):
Inherits from object.
3.2. def __init__(self, capacity):
Initializing this class requires the parameter capacity (the replay-buffer size, a single number).
The key part is that the buffer is built with deque(maxlen=self.capacity): it behaves like a list, but once it is full the oldest entry is dropped automatically, which is exactly the behaviour a replay buffer needs (a small demonstration follows after the link below).
Python collections模块之deque()详解_chl183的博客-CSDN博客
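A small demonstration of the maxlen behaviour, using the deque imported in section 1 (the numbers are just an illustration):

d = deque(maxlen=3)
for i in range(5):
    d.append(i)
print(d)        # deque([2, 3, 4], maxlen=3): the oldest entries 0 and 1 were dropped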
3.3. def store(self, state, action, reward, next_state):
np.expand_dims()_hong615771420的博客-CSDN博客
The original state is a plain list []; after the dimension expansion it becomes an array [[]] of shape (1, state_dim).
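A quick sketch of that shape change in store() (the state values are made up):

state = [0.1, -0.2, 0.3, 0.0]          # original state: a plain list
state = np.expand_dims(state, 0)       # -> array([[ 0.1, -0.2,  0.3,  0. ]])
print(state.shape)                     # (1, 4)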
3.4. def sample(self, size):
Samples from the replay buffer.
batch = random.sample(self.memory, size)
memory originally holds N transitions; random.sample draws size of them into a new list.
state, action, reward, next_state = zip(*batch)
The resulting state is a tuple whose elements are the individual states, each stored as an array.
np.concatenate(state, 0)
numpy数组拼接方法介绍_zyl1042635242的专栏-CSDN博客
This concatenates the arrays, turning the tuple into a single array with one state per row (a small sketch follows below).
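A minimal sketch of what sample() does with zip(*batch) and np.concatenate, using a tiny hand-made buffer (all transition values are made up):

buffer = replay_buffer(capacity=100)
buffer.store([0.1, 0.2], 0, 1.0, [0.3, 0.4])
buffer.store([0.5, 0.6], 1, 0.0, [0.7, 0.8])

state, action, reward, next_state = buffer.sample(2)
print(state.shape)       # (2, 2): one row per sampled state
print(action, reward)    # tuples, e.g. (1, 0) and (0.0, 1.0), depending on the sampled order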
4. Third class: DQN
class DQN(object):
    def __init__(self, state_dim, action_dim, learning_rate):
        self.eval_net = Net(state_dim, action_dim)
        self.target_net = Net(state_dim, action_dim)
        self.learn_step_counter = 0

        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)
        self.loss_fn = nn.MSELoss()
        # learning rate is kept constant

    def get_action(self, state, epsilon, action_dim):
        # note: epsilon here is the probability of taking the greedy action
        # (larger epsilon means less exploration)
        state = torch.FloatTensor(np.expand_dims(state, 0))
        if random.random() < epsilon:
            action_value = self.eval_net.forward(state)
            action = action_value.max(1)[1].data[0].item()
        else:
            action = random.choice(list(range(action_dim)))
        return action

    def training(self, buffer, batch_size, gamma, target_replace_iter):
        # update the target net
        if self.learn_step_counter % target_replace_iter == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        state, action, reward, next_state = buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        next_state = torch.FloatTensor(next_state)

        # Q(s, a) of the actions actually taken, from the eval net
        q_values = self.eval_net.forward(state)
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        # bootstrapped target from the target net (no terminal-state masking in this code)
        next_q_values = self.target_net.forward(next_state)
        next_q_value = next_q_values.max(1)[0].detach()
        expected_q_value = reward + next_q_value * gamma

        loss = self.loss_fn(q_value, expected_q_value.detach())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    def reset(self, state_dim, action_dim, learning_rate, memory):
        self.eval_net = Net(state_dim, action_dim)
        self.target_net = Net(state_dim, action_dim)
        self.learn_step_counter = 0
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)

        memory.clear()
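Putting the three classes together: a rough sketch of a training loop. This is only an illustration under my own assumptions: the CartPole environment, the older gym API (reset() returns the observation, step() returns four values), and all hyper-parameter values are my choices, not part of the original notes. Remember that get_action treats epsilon as the probability of acting greedily.

# sketch only, not the original author's training script
import gym

env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]      # 4
action_dim = env.action_space.n                 # 2

agent = DQN(state_dim, action_dim, learning_rate=1e-3)
buffer = replay_buffer(capacity=10000)

batch_size, gamma, target_replace_iter = 32, 0.99, 100
epsilon = 0.9        # probability of choosing the greedy action

for episode in range(200):
    state = env.reset()
    done = False
    while not done:
        action = agent.get_action(state, epsilon, action_dim)
        next_state, reward, done, info = env.step(action)
        buffer.store(state, action, reward, next_state)
        state = next_state
        if len(buffer) >= batch_size:
            agent.training(buffer, batch_size, gamma, target_replace_iter)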