找回密码
 立即注册
首页 业界区 安全 DQN算法

DQN算法

缣移双 2025-6-1 21:41:15
在Q-learning的学习过程中,我们需要维护一个 |S|x|A| 的Q表,当任务的状态空间和动作空间过大时,空间复杂度和时间复杂度都太高,为了解决这个问题,DQN采用神经网络来代替Q表,输入状态,预估该状态下采用不同动作的Q值
神经网络本身不是DQN的精髓,神经网络可以设计成MLP也可以设计成CNN等等,DQN的巧妙之处在于两个网络、经验回放等trick
 
Trick 1:两个网络

DQN算法采用了2个神经网络,分别是evaluate network(Q值网络)和target network(目标网络),两个网络结构完全相同:

  • evaluate network用用来计算策略选择的Q值和Q值迭代更新,梯度下降、反向传播的也是evaluate network
  • target network用来计算TD Target中下一状态的Q值,网络参数更新来自evaluate network网络参数复制
 设计target network目的是为了保持目标值稳定,防止过拟合,从而提高训练过程稳定和收敛速度
 
Trick 2:经验回放Experience Replay

DQN算法设计了一个固定大小的记忆库memory,用来记录经验,经验是一条一条的observation或者说是transition,它表示成 [s, a, r,s′] ,含义是当前状态→当前状态采取的动作→获得的奖励→转移到下一个状态
一开始记忆库memory中没有经验,也没有训练evaluate network,积累了一定数量的经验之后,再开始训练evaluate network。记忆库memory中的经验可以是自己历史的经验(epsilon-greedy得到的经验),也可以学习其他人的经验。训练evaluate network的时候,是从记忆库memory中随机选择batch size大小的经验,喂给evaluate network
设计记忆库memory并且随机选择经验喂给evaluate network的技巧打破了相邻训练样本之间相关性,试着想下,状态→动作→奖励→下一个状态的循环是具有关联的,用相邻的样本连续训练evaluate network会带来网络过拟合泛化能力差的问题,而经验回放技巧增强了训练样本之间的独立性
 
1.gif
2.gif
  1. import gym
  2. import numpy as np
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. import random
  7. from collections import deque
  8. # 定义DQN网络
  9. class DQN(nn.Module):
  10.     def __init__(self, input_dim, output_dim):
  11.         super(DQN, self).__init__()
  12.         self.fc1 = nn.Linear(input_dim, 64)
  13.         self.fc2 = nn.Linear(64, output_dim)
  14.     def forward(self, x):
  15.         x = torch.relu(self.fc1(x))
  16.         x = self.fc2(x)
  17.         return x
  18. # 定义DQN智能体
  19. class DQNAgent:
  20.     def __init__(self, state_dim, action_dim):
  21.         self.state_dim = state_dim
  22.         self.action_dim = action_dim
  23.         self.gamma = 0.99  # 折扣因子
  24.         self.epsilon = 1.0  # 探索率
  25.         self.epsilon_min = 0.01
  26.         self.epsilon_decay = 0.995
  27.         self.learning_rate = 0.001
  28.         self.memory = deque(maxlen=2000)
  29.         self.model = DQN(state_dim, action_dim)
  30.         self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
  31.         self.criterion = nn.MSELoss()
  32.     def remember(self, state, action, reward, next_state, done):
  33.         self.memory.append((state, action, reward, next_state, done))
  34.     def act(self, state):
  35.         if np.random.rand() <= self.epsilon:
  36.             return random.randrange(self.action_dim)
  37.         state = torch.FloatTensor(state).unsqueeze(0)
  38.         q_values = self.model(state)
  39.         action = torch.argmax(q_values, dim=1).item()
  40.         return action
  41.     def replay(self, batch_size):
  42.         if len(self.memory) < batch_size:
  43.             return
  44.         minibatch = random.sample(self.memory, batch_size)
  45.         for state, action, reward, next_state, done in minibatch:
  46.             state = torch.FloatTensor(state).unsqueeze(0)
  47.             next_state = torch.FloatTensor(next_state).unsqueeze(0)
  48.             target = reward
  49.             if not done:
  50.                 target = (reward + self.gamma * torch.max(self.model(next_state)).item())
  51.             target_f = self.model(state)
  52.             target_f[0][action] = target
  53.             self.optimizer.zero_grad()
  54.             output = self.model(state)
  55.             loss = self.criterion(output, target_f)
  56.             loss.backward()
  57.             self.optimizer.step()
  58.         if self.epsilon > self.epsilon_min:
  59.             self.epsilon *= self.epsilon_decay
  60. # 训练函数
  61. def train_dqn(agent, env, episodes=500, batch_size=32):
  62.     for episode in range(episodes):
  63.         state = env.reset()
  64.         if isinstance(state, tuple):
  65.             state = state[0]
  66.         state = np.eye(env.observation_space.n)[state]
  67.         total_reward = 0
  68.         done = False
  69.         while not done:
  70.             action = agent.act(state)
  71.             next_state, reward, terminated, truncated, info = env.step(action)
  72.             done = terminated or truncated
  73.             next_state = np.eye(env.observation_space.n)[next_state]
  74.             agent.remember(state, action, reward, next_state, done)
  75.             agent.replay(batch_size)
  76.             state = next_state
  77.             total_reward += reward
  78.         print(f"Episode {episode + 1}: Total Reward = {total_reward}")
  79. # 主函数
  80. if __name__ == "__main__":
  81.     env = gym.make('CliffWalking-v0')
  82.     state_dim = env.observation_space.n
  83.     action_dim = env.action_space.n
  84.     agent = DQNAgent(state_dim, action_dim)
  85.     train_dqn(agent, env)
  86.     env.close()
复制代码
DQN 
参考资料

DQN基本概念和算法流程(附Pytorch代码)
 

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册