PyTorch Reinforcement Learning (DQN) Tutorial
Reinforcement learning is a lively branch of machine learning that studies how an agent can learn an optimal behavior policy through trial and error in an environment, so as to maximize cumulative reward. The Deep Q-Network (DQN) was a major breakthrough in this field: it combines the powerful function-approximation ability of deep learning with the Q-learning algorithm and successfully tackles reinforcement learning problems with high-dimensional state spaces. This post walks through how to use PyTorch to train a DQN agent on the CartPole-v0 task from OpenAI Gym.
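The quantity DQN learns to approximate is the optimal action-value function. A standard statement of the Bellman optimality equation it is trained toward (the same structure used by the optimization step later in this post) is:

$$Q^{*}(s, a) = \mathbb{E}\big[\, r + \gamma \max_{a'} Q^{*}(s', a') \mid s, a \,\big]$$

where $r$ is the immediate reward, $\gamma$ is the discount factor, and $s'$ is the next state. DQN replaces a lookup table for $Q^{*}$ with a neural network and fits it to this target by regression.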
1. Environment Setup and Preparation
Before training a DQN agent, we first need to set up the development environment and import the required packages.
(1) Install Dependencies
Make sure PyTorch, OpenAI Gym, and the other required libraries are installed. They can be installed with:
pip install torch torchvision gym matplotlib numpy pillow
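Note that the code in this tutorial follows the classic Gym API: env.render(mode='rgb_array') and a four-value return from env.step(). Gym 0.26 and later changed both of these, and CartPole-v0 is deprecated in newer releases, so if you run into API errors, one option (a suggestion rather than a requirement) is to pin an older gym release:
pip install "gym<0.26"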
(2) Import the Required Modules
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
(3) Initialize the Environment and Device
# Create the CartPole-v0 environment
env = gym.make('CartPole-v0').unwrapped
# Select the device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set up matplotlib (interactive mode, with special handling for IPython)
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()
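As an optional, purely illustrative sanity check, you can confirm which device was selected and that the environment resets cleanly:
# Optional: confirm the selected device and the initial observation
print(device)
print(env.reset())  # a 4-element state: cart position, cart velocity, pole angle, pole angular velocity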
2. Experience Replay
Experience replay stores the agent's interactions with the environment and learns from randomly sampled batches of them. Sampling at random breaks the correlation between consecutive transitions, which improves the stability and convergence speed of training.
(1) Define Transition and ReplayMemory
# Transition: a named tuple representing a single transition in the environment
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))
# A cyclic buffer that holds recent transitions and supports random sampling
class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
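As a small illustration of how this buffer behaves (the names and tensor shapes below are invented for the example and are not part of the training code):
# Illustration only: push a few dummy transitions and draw a random batch
demo_memory = ReplayMemory(100)
for i in range(5):
    s = torch.zeros(1, 3, 40, 90)                  # placeholder "state" tensor
    a = torch.tensor([[i % 2]], dtype=torch.long)  # placeholder action
    demo_memory.push(s, a, s, torch.tensor([1.0]))
print(len(demo_memory))        # 5
batch = demo_memory.sample(3)  # a list of 3 randomly chosen Transition tuples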
3. Building the DQN Model
The DQN here is a convolutional neural network that takes the current (image-based) state and predicts a Q-value for every action, which is then used to guide the agent toward the best action.
(1) Define the DQN Network
class DQN(nn.Module):
    def __init__(self, h, w, outputs):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)
        # The size of the linear layer's input depends on the conv output,
        # which in turn depends on the input image size, so compute it here.
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    # Called with either one element (during action selection) or a batch (during optimization).
    # Returns a tensor of Q-values, one per action.
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))
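To get a feel for what the network produces, here is a purely illustrative check (the 40x90 input size is just an example): feeding a batch of screens returns one Q-value per action for each element of the batch.
# Illustration only: a DQN for 40x90 RGB inputs and 2 actions
demo_net = DQN(h=40, w=90, outputs=2)
dummy_input = torch.randn(4, 3, 40, 90)  # a batch of 4 fake screens
q_values = demo_net(dummy_input)
print(q_values.shape)                    # torch.Size([4, 2]): one Q-value per action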
4. Input Processing and State Extraction
The agent's current state is extracted from the rendered screen and preprocessed so that it is suitable as input to the neural network.
(1) Define the Screen Processing Functions
resize = T.Compose([
    T.ToPILImage(),
    # Note: Image.CUBIC was removed in newer Pillow releases; on recent versions you may
    # need Image.BICUBIC or torchvision's T.InterpolationMode.BICUBIC instead.
    T.Resize(40, interpolation=Image.CUBIC),
    T.ToTensor()
])

def get_cart_location(screen_width):
    world_width = env.x_threshold * 2
    scale = screen_width / world_width
    return int(env.state[0] * scale + screen_width / 2.0)  # middle of the cart, in pixels

def get_screen():
    # gym returns the screen as HWC (400x600x3 for CartPole); transpose to torch order (CHW)
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    _, screen_height, screen_width = screen.shape
    # Keep only the vertical band that contains the cart and pole
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    cart_location = get_cart_location(screen_width)
    # Crop a horizontal window centered on the cart
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2, cart_location + view_width // 2)
    screen = screen[:, :, slice_range]
    # Convert to float, rescale to [0, 1], and convert to a torch tensor
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)
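It is worth checking what get_screen() actually returns; the exact width depends on the render resolution, so the shape shown in the comment is indicative rather than exact.
# Illustration: inspect and display one processed screen
env.reset()
example_screen = get_screen()
print(example_screen.shape)  # e.g. torch.Size([1, 3, 40, 90])
plt.figure()
plt.imshow(example_screen.cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')
plt.title('Example extracted screen')
plt.show()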
5. Training the DQN Agent
(1) Hyperparameters and Model Initialization
# Hyperparameters
BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
TARGET_UPDATE = 10
# Get the screen size so the network layers can be sized correctly
init_screen = get_screen()
_, _, screen_height, screen_width = init_screen.shape
# Number of actions in the gym action space
n_actions = env.action_space.n
# Initialize the policy network and the target network
policy_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net = DQN(screen_height, screen_width, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
# Optimizer
optimizer = optim.RMSprop(policy_net.parameters())
# Replay memory
memory = ReplayMemory(10000)
# Step counter used for the epsilon-greedy decay
steps_done = 0
# Epsilon-greedy action selection
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # Exploit: pick the action with the largest predicted Q-value
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        # Explore: pick a random action
        return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
# Plot episode durations (plus a 100-episode moving average) during training
episode_durations = []

def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Also plot the 100-episode moving average
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
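With these constants, the exploration rate eps_threshold used in select_action above decays smoothly from EPS_START toward EPS_END as steps_done grows. A few sample values (rounded, assuming the settings above) make the schedule concrete:
# Illustration: how the epsilon-greedy threshold decays with the settings above
for steps in (0, 200, 1000):
    eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps / EPS_DECAY)
    print(steps, round(eps, 3))  # roughly 0.9, 0.363, 0.056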
(2) Define the Optimization Step
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Convert a batch of Transitions into a Transition of batches
    batch = Transition(*zip(*transitions))
    # Mask of non-final states (a final state is one after which the episode ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Q(s, a): the Q-values of the actions that were actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # max_a' Q_target(s', a') for non-final next states; 0 for final states
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    # The expected Q-values (the TD target)
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Huber loss between the current estimates and the target
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients to [-1, 1] for stability
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
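In equation form, optimize_model() fits the policy network to a bootstrapped target computed with the target network, using the Huber (smooth L1) loss:

$$y = r + \gamma \max_{a'} Q_{\text{target}}(s', a') \quad (\text{with the max term taken as } 0 \text{ when } s' \text{ is terminal}), \qquad \mathcal{L} = \text{SmoothL1}\big(Q_{\text{policy}}(s, a),\, y\big)$$

Using the slowly updated target network for $y$, rather than the policy network itself, keeps the regression target from shifting at every step and is one of the main stabilizing tricks of DQN.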
(3) Run the Training Loop
# Train the agent
num_episodes = 50
for i_episode in range(num_episodes):
    # Initialize the environment and the initial state (difference between two screens)
    env.reset()
    last_screen = get_screen()
    current_screen = get_screen()
    state = current_screen - last_screen
    for t in count():
        # Select and perform an action
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        # Observe the new state
        last_screen = current_screen
        current_screen = get_screen()
        if not done:
            next_state = current_screen - last_screen
        else:
            next_state = None
        # Store the transition in memory
        memory.push(state, action, next_state, reward)
        # Move to the next state
        state = next_state
        # Perform one optimization step on the policy network
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break
    # Periodically copy the policy network's weights into the target network
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
print('Complete')
env.render()
env.close()
plt.ioff()
plt.show()
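After training, you may want to roll out one greedy episode (no exploration) and see how long the pole stays up. The sketch below is a minimal illustration, not part of the original tutorial: it reuses the helpers defined above, recreates the environment because it was closed above, and assumes the same old Gym step/render API used throughout this post.
# Minimal greedy-rollout sketch (assumes the helpers and networks defined above)
env = gym.make('CartPole-v0').unwrapped  # recreate the environment closed above
policy_net.eval()                        # use running BatchNorm statistics at evaluation time
env.reset()
last_screen = get_screen()
current_screen = get_screen()
state = current_screen - last_screen
total_steps = 0
done = False
while not done and total_steps < 500:    # cap the episode: the unwrapped env has no time limit
    with torch.no_grad():
        action = policy_net(state).max(1)[1].view(1, 1)  # always take the highest-Q action
    _, _, done, _ = env.step(action.item())
    last_screen = current_screen
    current_screen = get_screen()
    state = current_screen - last_screen
    total_steps += 1
print('Greedy episode lasted', total_steps, 'steps')
env.close()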
6. Summary and Outlook
In this post you implemented a DQN agent with PyTorch and trained it on the CartPole-v0 task. The core idea of DQN is to approximate the Q-function with a neural network, which makes reinforcement learning tractable in high-dimensional state spaces. During training, experience replay and a separate target network stabilize learning, while an epsilon-greedy policy balances exploration and exploitation.
Reinforcement learning is a broad and deep field, and DQN is only one of its landmarks. From here you can explore other algorithms, such as Deep Deterministic Policy Gradient (DDPG) and Proximal Policy Optimization (PPO), to handle more complex continuous action spaces and multi-agent settings. 编程狮 will continue to publish more quality tutorials on reinforcement learning and deep learning to support your journey in artificial intelligence.