Getting Started with the PyTorch Distributed RPC Framework

Updated 2025-06-23 09:56

In distributed deep learning, PyTorch's torch.distributed.rpc package provides a flexible and powerful mechanism for building complex distributed applications. Through remote procedure calls (RPC) and remote references (RRef), developers can pass data and method calls between processes and implement efficient distributed training and inference. This article walks through detailed code examples and the underlying principles to help you quickly grasp the core concepts and usage of the PyTorch distributed RPC framework.

1. RPC and RRef: The Core of Distributed Communication

1.1 RPC Basics

RPC (Remote Procedure Call) lets one process (the client) invoke a function in another process (the server) as if it were a local call. In PyTorch's torch.distributed.rpc package, RPC comes in two flavors, rpc_sync and rpc_async, for blocking and non-blocking communication respectively.

import torch
import torch.distributed.rpc as rpc


# Blocking RPC call: returns the result of torch.add computed on "dest_worker"
result = rpc.rpc_sync("dest_worker", torch.add, args=(torch.tensor(2), 3))


# Non-blocking RPC call: returns a Future immediately
future = rpc.rpc_async("dest_worker", torch.add, args=(torch.tensor(2), 3))
result = future.wait()
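Both calls assume that RPC has already been initialized on every participating process (the launch scripts later in this article show the full pattern). Below is a minimal, self-contained two-process sketch; the worker names, address, and port are chosen purely for illustration:

import os
import torch
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run(rank, world_size):
    # Rendezvous settings read by init_rpc; the address and port are illustrative
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    if rank == 0:
        # "worker1" is the name the other process registered above
        result = rpc.rpc_sync("worker1", torch.add, args=(torch.tensor(2), 3))
        print(result)  # tensor(5)
    rpc.shutdown()  # blocks until outstanding RPC work has finished


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2, join=True)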

1.2 RRef: Remote Object References

An RRef (Remote Reference) points to an object that lives in another process. It lets developers share and operate on data across processes without worrying about where the object physically resides.

import torch
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef


# Create an object on the remote worker and get back an RRef to it
rref = rpc.remote("dest_worker", torch.randn, args=(3, 3))


# Copy the remote object's value to the local process
value = rref.to_here()
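Besides copying the value back with to_here(), an RRef can itself be passed as an argument to further RPC calls, so the computation runs where the data lives instead of the data being shipped around. A small sketch (the helper name sum_on_owner is made up for illustration, and rref is the remote tensor created above):

def sum_on_owner(rref):
    # Executes on the process that owns the RRef, so local_value() is a local access
    return rref.local_value().sum()


total = rpc.rpc_sync(rref.owner(), sum_on_owner, args=(rref,))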

2. Distributed Reinforcement Learning Example

2.1 Defining the Policy Network

The policy network is the core component of a reinforcement learning agent: given the current state, it produces a distribution over actions.

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)


    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
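As a quick local sanity check (independent of any RPC setup): the input dimension 4 matches CartPole's observation size and the output dimension 2 its two discrete actions.

import torch

policy = Policy()
probs = policy(torch.randn(1, 4))  # shape (1, 2); each row sums to 1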

2.2 Implementing the Observer and the Agent

The observer interacts with the environment, while the agent updates the policy network from the data the observers collect. The two sides communicate through RPC, using a small _call_method helper that runs a method on the object behind an RRef.

import gym
import torch
import torch.distributed.rpc as rpc


def _call_method(method, rref, *args, **kwargs):
    # Run `method` on the object behind `rref`; this executes on the RRef's owner
    return method(rref.local_value(), *args, **kwargs)


class Observer:
    def __init__(self, rank):
        self.rank = rank
        # Classic Gym API: reset() returns the observation, step() a 4-tuple
        self.env = gym.make('CartPole-v1')


    def run_episode(self, agent_rref, n_steps):
        state = self.env.reset()
        for _ in range(n_steps):
            # Ask the agent (on its owner process) to pick an action for this state
            action = rpc.rpc_sync(agent_rref.owner(), _call_method, args=(Agent.select_action, agent_rref, self.rank, state))
            next_state, reward, done, _ = self.env.step(action)
            # Report the resulting reward back to the agent
            rpc.rpc_sync(agent_rref.owner(), _call_method, args=(Agent.report_reward, agent_rref, self.rank, reward))
            state = next_state
            if done:
                break


class Agent:
    def __init__(self, world_size):
        # In a complete example this list holds RRefs to the remote Observer instances
        self.ob_rrefs = []
        self.policy = Policy()
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=1e-2)


    def select_action(self, ob_id, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        return action.item()


    def report_reward(self, ob_id, reward):
        # Save the reward for the later policy update
        pass


    def run_episode(self, n_steps):
        # Kick off an episode on every observer
        pass


    def finish_episode(self):
        # Update the policy network
        pass
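The stubbed methods omit the bookkeeping for brevity. As one possible way to fill them in, here is a minimal single-observer sketch of finish_episode, assuming select_action also appends m.log_prob(action) to self.saved_log_probs, report_reward appends reward to self.rewards (both lists initialized in __init__), and a discount factor of 0.99:

    def finish_episode(self):
        # REINFORCE-style update from the saved rewards and log-probabilities
        R, returns = 0, []
        for r in reversed(self.rewards):
            R = r + 0.99 * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        loss = torch.stack([-lp * ret for lp, ret in zip(self.saved_log_probs, returns)]).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.rewards, self.saved_log_probs = [], []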

2.3 Launching Distributed Training

import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # Rank 0 hosts the agent and drives training
        rpc.init_rpc("agent", rank=rank, world_size=world_size)
        agent = Agent(world_size)
        for episode in range(100):
            agent.run_episode(n_steps=10)
            agent.finish_episode()
    else:
        # All other ranks register as observers and wait for RPC requests
        rpc.init_rpc(f"observer_{rank}", rank=rank, world_size=world_size)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)
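In this launch script, rank 0 becomes the agent and every other rank an observer. rpc.init_rpc acts as a rendezvous, blocking until all world_size processes have joined, and rpc.shutdown() waits for outstanding RPC work to complete before tearing the group down.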

3. Distributed Model-Parallel Training Example

3.1 Defining the Distributed RNN Model

import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed.rpc import RRef


def _call_method(method, rref, *args, **kwargs):
    # Run `method` on the object behind `rref`; executed on the RRef's owner
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    # Forward a method call to the RRef's owner via a blocking RPC
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)


def _parameter_rrefs(module):
    # Wrap each parameter of a local module in an RRef so that the
    # DistributedOptimizer can refer to it uniformly
    return [RRef(p) for p in module.parameters()]


class EmbeddingTable(nn.Module):
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()


    def forward(self, input):
        return self.drop(self.encoder(input.cuda())).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)


    def forward(self, output):
        return self.decoder(self.drop(output))


class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        # The embedding table and the decoder live on the parameter server `ps`
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # The LSTM stays local on the trainer
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))


    def forward(self, input, hidden):
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden


    def parameter_rrefs(self):
        # Collect RRefs to every parameter, both remote (embedding table, decoder)
        # and local (LSTM), for the DistributedOptimizer below
        remote_params = []
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        remote_params.extend(_parameter_rrefs(self.rnn))
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params
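In forward, the embedding lookup and the decoder run on the parameter server via _remote_method, while the LSTM runs locally on the trainer. Distributed autograd, used in the training loop below, stitches the backward pass together across both processes.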

3.2 Implementing the Distributed Training Loop

import torch
import torch.distributed.autograd as dist_autograd
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer


def run_trainer():
    model = RNNModel('ps', ntoken=10, ninp=2, nhid=3, nlayers=4)
    # The DistributedOptimizer updates local and remote parameters in one step
    opt = DistributedOptimizer(optim.SGD, model.parameter_rrefs(), lr=0.05)
    criterion = torch.nn.CrossEntropyLoss()


    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(5, 3) % 10
            target = torch.LongTensor(5, 10) % 3
            yield data, target


    for epoch in range(10):
        for data, target in get_next_batch():
            with dist_autograd.context() as context_id:
                output, hidden = model(data, (torch.randn(4, 3, 3), torch.randn(4, 3, 3)))
                loss = criterion(output, target)
                dist_autograd.backward(context_id, [loss])
                opt.step(context_id)
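Each iteration opens a fresh distributed autograd context; gradients are accumulated inside that context rather than in param.grad, so there is no need to zero gradients between iterations.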


import os

import torch.multiprocessing as mp


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        # Rank 1 is the trainer: it builds the model and runs the training loop
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        run_trainer()
    else:
        # Rank 0 is the parameter server; the name "ps" matches RNNModel('ps', ...)
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
    rpc.shutdown()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2, join=True)
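Note that EmbeddingTable moves the embedding onto a GPU with .cuda(), so the process acting as the parameter server needs a visible CUDA device; to run this example on CPU only, remove the .cuda() calls from EmbeddingTable.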

4. Summary and Outlook

This article has covered the core concepts and usage of the PyTorch distributed RPC framework. RPC and RRef are powerful tools for building complex models and training pipelines across processes. From here, you can explore applying these techniques in real projects to tackle larger-scale training and inference workloads. 编程狮 will continue to publish in-depth tutorials on distributed deep learning to support your growth.
