PyTorch 强化学习(DQN)教程

2025-06-18 17:14 更新

强化学习是机器学习领域中一个充满活力的分支,它研究如何使智能体在环境中通过试错的方式学习最优行为策略,以最大化累积奖励。深度 Q 网络(DQN)作为强化学习领域的一个重要突破,将深度学习的强大函数拟合能力与 Q 学习算法相结合,成功解决了高维状态空间下的强化学习问题。本文将带领读者深入浅出地学习如何使用 PyTorch 在 OpenAI Gym 的 CartPole-v0 任务上训练 DQN 智能体,开启强化学习的探索之旅。

一、环境搭建与准备工作

在开始训练 DQN 智能体之前,我们需要先搭建好开发环境并导入必要的软件包。

(一)安装依赖库

确保已安装 PyTorch、OpenAI Gym 和其他所需的依赖库。可以使用以下命令进行安装:

  1. pip install torch gym matplotlib numpy pillow

(二)导入必要的模块

  1. import gym
  2. import math
  3. import random
  4. import numpy as np
  5. import matplotlib
  6. import matplotlib.pyplot as plt
  7. from collections import namedtuple
  8. from itertools import count
  9. from PIL import Image
  10. import torch
  11. import torch.nn as nn
  12. import torch.optim as optim
  13. import torch.nn.functional as F
  14. import torchvision.transforms as T

(三)初始化环境与设备

  1. ## 创建 CartPole-v0 环境
  2. env = gym.make('CartPole-v0').unwrapped
  3. ## 设置设备(GPU 或 CPU)
  4. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  5. ## 设置 matplotlib 交互模式
  6. is_ipython = 'inline' in matplotlib.get_backend()
  7. if is_ipython:
  8. from IPython import display
  9. plt.ion()

二、经验回放机制

经验回放是一种通过存储智能体与环境交互的经验,并从中随机采样进行学习的方法,它可以打破数据之间的相关性,提高模型的稳定性和收敛速度。

(一)定义 Transition 和 ReplayMemory

  1. ## 定义 Transition 命名元组
  2. Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
  3. ## 定义经验回放内存类
  4. class ReplayMemory(object):
  5. def __init__(self, capacity):
  6. self.capacity = capacity
  7. self.memory = []
  8. self.position = 0
  9. def push(self, *args):
  10. """Saves a transition."""
  11. if len(self.memory) < self.capacity:
  12. self.memory.append(None)
  13. self.memory[self.position] = Transition(*args)
  14. self.position = (self.position + 1) % self.capacity
  15. def sample(self, batch_size):
  16. return random.sample(self.memory, batch_size)
  17. def __len__(self):
  18. return len(self.memory)

三、构建 DQN 模型

DQN 是一个卷积神经网络,用于根据当前状态预测每个动作的 Q 值,从而指导智能体选择最优动作。

(一)定义 DQN 网络结构

  1. class DQN(nn.Module):
  2. def __init__(self, h, w, outputs):
  3. super(DQN, self).__init__()
  4. self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
  5. self.bn1 = nn.BatchNorm2d(16)
  6. self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
  7. self.bn2 = nn.BatchNorm2d(32)
  8. self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
  9. self.bn3 = nn.BatchNorm2d(32)
  10. def conv2d_size_out(size, kernel_size=5, stride=2):
  11. return (size - (kernel_size - 1) - 1) // stride + 1
  12. convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
  13. convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
  14. linear_input_size = convw * convh * 32
  15. self.head = nn.Linear(linear_input_size, outputs)
  16. def forward(self, x):
  17. x = F.relu(self.bn1(self.conv1(x)))
  18. x = F.relu(self.bn2(self.conv2(x)))
  19. x = F.relu(self.bn3(self.conv3(x)))
  20. return self.head(x.view(x.size(0), -1))

四、输入处理与状态提取

从环境中提取智能体所需的当前状态信息,并进行预处理,使其适合作为神经网络的输入。

(一)定义图像处理函数

  1. resize = T.Compose([
  2. T.ToPILImage(),
  3. T.Resize(40, interpolation=Image.CUBIC),
  4. T.ToTensor()
  5. ])
  6. def get_cart_location(screen_width):
  7. world_width = env.x_threshold * 2
  8. scale = screen_width / world_width
  9. return int(env.state[0] * scale + screen_width / 2.0)
  10. def get_screen():
  11. screen = env.render(mode='rgb_array').transpose((2, 0, 1))
  12. _, screen_height, screen_width = screen.shape
  13. screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
  14. view_width = int(screen_width * 0.6)
  15. cart_location = get_cart_location(screen_width)
  16. if cart_location < view_width // 2:
  17. slice_range = slice(view_width)
  18. elif cart_location > (screen_width - view_width // 2):
  19. slice_range = slice(-view_width, None)
  20. else:
  21. slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
  22. screen = screen[:, :, slice_range]
  23. screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
  24. screen = torch.from_numpy(screen)
  25. return resize(screen).unsqueeze(0).to(device)

五、训练 DQN 智能体

(一)设置超参数与初始化模型

  1. ## 超参数设置
  2. BATCH_SIZE = 128
  3. GAMMA = 0.999
  4. EPS_START = 0.9
  5. EPS_END = 0.05
  6. EPS_DECAY = 200
  7. TARGET_UPDATE = 10
  8. ## 获取屏幕尺寸
  9. init_screen = get_screen()
  10. _, _, screen_height, screen_width = init_screen.shape
  11. ## 获取动作空间维度
  12. n_actions = env.action_space.n
  13. ## 初始化策略网络和目标网络
  14. policy_net = DQN(screen_height, screen_width, n_actions).to(device)
  15. target_net = DQN(screen_height, screen_width, n_actions).to(device)
  16. target_net.load_state_dict(policy_net.state_dict())
  17. target_net.eval()
  18. ## 定义优化器
  19. optimizer = optim.RMSprop(policy_net.parameters())
  20. ## 初始化经验回放内存
  21. memory = ReplayMemory(10000)
  22. ## 初始化步骤计数
  23. steps_done = 0
  24. ## 定义选择动作的函数
  25. def select_action(state):
  26. global steps_done
  27. sample = random.random()
  28. eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
  29. steps_done += 1
  30. if sample > eps_threshold:
  31. with torch.no_grad():
  32. return policy_net(state).max(1)[1].view(1, 1)
  33. else:
  34. return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
  35. ## 定义绘制训练曲线的函数
  36. episode_durations = []
  37. def plot_durations():
  38. plt.figure(2)
  39. plt.clf()
  40. durations_t = torch.tensor(episode_durations, dtype=torch.float)
  41. plt.title('Training...')
  42. plt.xlabel('Episode')
  43. plt.ylabel('Duration')
  44. plt.plot(durations_t.numpy())
  45. if len(durations_t) >= 100:
  46. means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
  47. means = torch.cat((torch.zeros(99), means))
  48. plt.plot(means.numpy())
  49. plt.pause(0.001)
  50. if is_ipython:
  51. display.clear_output(wait=True)
  52. display.display(plt.gcf())

(二)定义优化模型函数

  1. def optimize_model():
  2. if len(memory) < BATCH_SIZE:
  3. return
  4. transitions = memory.sample(BATCH_SIZE)
  5. batch = Transition(*zip(*transitions))
  6. non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
  7. non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
  8. state_batch = torch.cat(batch.state)
  9. action_batch = torch.cat(batch.action)
  10. reward_batch = torch.cat(batch.reward)
  11. state_action_values = policy_net(state_batch).gather(1, action_batch)
  12. next_state_values = torch.zeros(BATCH_SIZE, device=device)
  13. next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
  14. expected_state_action_values = (next_state_values * GAMMA) + reward_batch
  15. loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
  16. optimizer.zero_grad()
  17. loss.backward()
  18. for param in policy_net.parameters():
  19. param.grad.data.clamp_(-1, 1)
  20. optimizer.step()

(三)执行训练循环

  1. ## 训练智能体
  2. num_episodes = 50
  3. for i_episode in range(num_episodes):
  4. env.reset()
  5. last_screen = get_screen()
  6. current_screen = get_screen()
  7. state = current_screen - last_screen
  8. for t in count():
  9. action = select_action(state)
  10. _, reward, done, _ = env.step(action.item())
  11. reward = torch.tensor([reward], device=device)
  12. last_screen = current_screen
  13. current_screen = get_screen()
  14. if not done:
  15. next_state = current_screen - last_screen
  16. else:
  17. next_state = None
  18. memory.push(state, action, next_state, reward)
  19. state = next_state
  20. optimize_model()
  21. if done:
  22. episode_durations.append(t + 1)
  23. plot_durations()
  24. break
  25. if i_episode % TARGET_UPDATE == 0:
  26. target_net.load_state_dict(policy_net.state_dict())
  27. print('Complete')
  28. env.render()
  29. env.close()
  30. plt.ioff()
  31. plt.show()

六、总结与展望

通过本文,您已成功使用 PyTorch 实现了 DQN 智能体,并在 CartPole-v0 任务上进行了训练。DQN 的核心思想是利用神经网络来近似 Q 函数,从而解决高维状态空间下的强化学习问题。在训练过程中,我们通过经验回放机制和目标网络来稳定学习过程,并采用 epsilon-greedy 策略来平衡探索与利用。

强化学习是一个广阔而深刻的领域,DQN 仅是其中的一颗明珠。未来,您可以进一步探索其他强化学习算法,如深度确定性策略梯度(DDPG)、 proximal 策略优化(PPO)等,以应对更复杂的连续动作空间和多智能体环境。编程狮将持续为您带来更多强化学习和深度学习的优质教程,助力您在人工智能的道路上不断前行。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号