
DQN Code Walkthrough

These are informal notes on the DQN code I use; I have a rough understanding of it and am just jotting things down.

1. Imports

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import numpy as np
    from collections import deque
    import random

2. The first class: building the neural network

    class Net(nn.Module):
        def __init__(self, state_dim, action_dim):
            super(Net, self).__init__()
            self.state_dim = state_dim
            self.action_dim = action_dim
            self.fc1 = nn.Linear(self.state_dim, 64)
            self.fc1.weight.data.normal_(0, 0.1)
            self.fc2 = nn.Linear(64, 128)
            self.fc2.weight.data.normal_(0, 0.1)
            self.fc3 = nn.Linear(128, self.action_dim)
            self.fc3.weight.data.normal_(0, 0.1)

        def forward(self, x):
            x = x.view(x.size(0), x.size(-1))
            x = self.fc1(x)
            x = F.relu(x)
            x = self.fc2(x)
            x = F.relu(x)
            x = self.fc3(x)
            return x

For self and __init__, the following post can be used as a reference:

解惑(一) ----- super(XXX, self).__init__()到底是代表什么含义_奋斗の博客-CSDN博客

My own notes:

2.1.class Net(nn.Module):

   Net inherits from nn.Module, i.e. its parent class is nn.Module.

2.2.__init__(self, state_dim, action_dim)

   self is the instance itself; inside this class, self can roughly be read as the particular Net instance being created.

   state_dim and action_dim are the parameters the network needs at construction time: the state dimension and the action dimension (two integers).

2.3.super(Net, self).__init__()

    The Net class inherits from the parent class nn.Module.

    super(Net, self).__init__() initializes the attributes inherited from the parent class nn.Module (so you do not have to write that initialization yourself); a small standalone example follows.
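
To see what the super call buys you, here is a small standalone sketch (not part of the DQN code itself): nn.Module's own __init__ sets up the internal bookkeeping that lets layers assigned to self be registered as parameters.

    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super(Tiny, self).__init__()   # run nn.Module's __init__ first
            self.fc = nn.Linear(2, 3)      # this layer is now registered properly

    print(list(Tiny().parameters()))       # the weight and bias of fc show up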

2.4. The input x to forward is the state, a tensor of shape [1, state_dim]; the original state is a plain vector, so one dimension is added before it is converted to a tensor.

   The network's final output x is a tensor of shape [1, action_dim] (carrying grad_fn), with one value per action.
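
A quick shape check, using the hypothetical values state_dim = 4 and action_dim = 2:

    net = Net(4, 2)                                  # state_dim = 4, action_dim = 2
    state = [0.1, 0.2, 0.3, 0.4]                     # the raw state, a plain vector
    x = torch.FloatTensor(np.expand_dims(state, 0))  # shape becomes [1, 4]
    q = net(x)                                       # shape [1, 2], carries grad_fn
    print(x.shape, q.shape)                          # torch.Size([1, 4]) torch.Size([1, 2])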

3. The second class: the replay buffer

    class replay_buffer(object):
        def __init__(self, capacity):
            self.capacity = capacity
            self.memory = deque(maxlen=self.capacity)  # key: a deque drops the oldest entry once full

        def store(self, state, action, reward, next_state):
            state = np.expand_dims(state, 0)
            next_state = np.expand_dims(next_state, 0)
            self.memory.append([state, action, reward, next_state])

        def sample(self, size):
            batch = random.sample(self.memory, size)
            state, action, reward, next_state = zip(*batch)
            return np.concatenate(state, 0), action, reward, np.concatenate(next_state, 0)

        def __len__(self):
            return len(self.memory)

3.1.class replay_buffer(object):

   Inherits from object.

3.2.def __init__(self, capacity):

   Initializing this class requires one parameter, capacity (the replay-buffer capacity, a single number).

   The key point is that the buffer is created with deque(maxlen=self.capacity). A deque behaves much like a list, but once it is full the oldest entry is discarded automatically, which is exactly what a replay buffer needs; a small example follows the link below.

Python collections模块之deque()详解_chl183的博客-CSDN博客
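
A tiny standalone sketch of that deque behaviour (a capacity of 3, chosen only for illustration):

    d = deque(maxlen=3)
    for i in range(5):
        d.append(i)
    print(d)   # deque([2, 3, 4], maxlen=3): the oldest entries 0 and 1 were dropped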

3.3.def store(self, state, action, reward, next_state ):

np.expand_dims()_hong615771420的博客-CSDN博客

   The original state is a list [ ]; after expanding its dimensions it becomes an array [[ ]].
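
For instance, with a hypothetical 4-dimensional state:

    state = [0.1, 0.2, 0.3, 0.4]        # as an array this has shape (4,)
    state = np.expand_dims(state, 0)    # now shape (1, 4), i.e. [[0.1, 0.2, 0.3, 0.4]]
    print(state.shape)                  # (1, 4)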

3.4.def sample(self, size):

   Samples from the replay buffer.

   batch = random.sample(self.memory, size)

   The original memory is a 1×N sequence of transitions; random.sample draws size of them, giving a 1×size list.

   state, action, reward, next_state = zip(* batch)

   The resulting state is a tuple whose elements are the individual states, each stored as an array.

   np.concatenate(state, 0)

    numpy数组拼接方法介绍_zyl1042635242的专栏-CSDN博客

   Concatenates the arrays, turning the tuple into a single array with one state per row.
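
A small sketch of what sample returns, assuming a 4-dimensional state and a batch of 2 (all numbers made up for illustration):

    buffer = replay_buffer(1000)
    buffer.store([0.1, 0.2, 0.3, 0.4], 0, 1.0, [0.2, 0.3, 0.4, 0.5])
    buffer.store([0.5, 0.6, 0.7, 0.8], 1, 0.0, [0.6, 0.7, 0.8, 0.9])
    state, action, reward, next_state = buffer.sample(2)
    print(state.shape)   # (2, 4): one row per sampled transition
    print(action)        # a plain tuple of actions, e.g. (0, 1)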

4. The third class: DQN

    class DQN(object):
        def __init__(self, state_dim, action_dim, learning_rate):
            self.eval_net = Net(state_dim, action_dim)
            self.target_net = Net(state_dim, action_dim)
            self.learn_step_counter = 0
            self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)
            self.loss_fn = nn.MSELoss()
            # the learning rate is kept constant

        def get_action(self, state, epsilon, action_dim):
            state = torch.FloatTensor(np.expand_dims(state, 0))
            if random.random() < epsilon:
                # note: here epsilon is the probability of acting greedily
                action_value = self.eval_net.forward(state)
                action = action_value.max(1)[1].data[0].item()
            else:
                # otherwise pick a random action for exploration
                action = random.choice(list(range(action_dim)))
            return action

        def training(self, buffer, batch_size, gamma, target_replace_iter):
            # update the target net every target_replace_iter learning steps
            if self.learn_step_counter % target_replace_iter == 0:
                self.target_net.load_state_dict(self.eval_net.state_dict())
            self.learn_step_counter += 1
            state, action, reward, next_state = buffer.sample(batch_size)
            state = torch.FloatTensor(state)
            action = torch.LongTensor(action)
            reward = torch.FloatTensor(reward)
            next_state = torch.FloatTensor(next_state)
            # Q(s, a) from the eval net for the actions actually taken
            q_values = self.eval_net.forward(state)
            q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
            # max_a' Q_target(s', a'), detached so no gradient flows into the target net
            next_q_values = self.target_net.forward(next_state)
            next_q_value = next_q_values.max(1)[0].detach()
            expected_q_value = reward + next_q_value * gamma
            loss = self.loss_fn(q_value, expected_q_value.detach())
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return loss

        def reset(self, state_dim, action_dim, learning_rate, memory):
            self.eval_net = Net(state_dim, action_dim)
            self.target_net = Net(state_dim, action_dim)
            self.learn_step_counter = 0
            self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=learning_rate)
            memory.clear()
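
To round things off, here is a rough sketch of how the three classes might be wired together in a training loop. It is written against the classic gym API (env.reset() returns the state, env.step() returns four values) with CartPole-v1 as an assumed environment; all hyperparameter values are illustrative, and note that epsilon here is the probability of acting greedily, matching get_action above.

    import gym

    env = gym.make('CartPole-v1')
    state_dim = env.observation_space.shape[0]   # 4 for CartPole
    action_dim = env.action_space.n              # 2 for CartPole

    agent = DQN(state_dim, action_dim, learning_rate=1e-3)
    buffer = replay_buffer(capacity=10000)
    batch_size, gamma, target_replace_iter, epsilon = 32, 0.99, 100, 0.9

    for episode in range(200):
        state = env.reset()
        done = False
        while not done:
            action = agent.get_action(state, epsilon, action_dim)
            next_state, reward, done, info = env.step(action)
            buffer.store(state, action, reward, next_state)
            state = next_state
            if len(buffer) >= batch_size:
                agent.training(buffer, batch_size, gamma, target_replace_iter)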
