PyTorch XLA 设备上的 PyTorch

2025-06-25 10:39 更新

一、PyTorch XLA 简介

PyTorch XLA 是 PyTorch 的一个扩展,用于在 XLA 设备(如 TPU)上运行模型。它提供了与常规 PyTorch 类似的接口,但增加了一些额外功能以支持 XLA 设备。以下是使用 PyTorch XLA 的基本步骤和注意事项。

二、XLA 设备基础操作

2.1 创建和打印 XLA 张量

import torch
import torch_xla
import torch_xla.core.xla_model as xm


## 获取 XLA 设备
device = xm.xla_device()


## 创建 XLA 张量
t = torch.randn(2, 2, device=device)


## 打印设备和张量
print(t.device)
print(t)

2.2 XLA 张量的基本操作

XLA 张量支持与 CPU 和 CUDA 张量类似的操作,例如加法和矩阵乘法。

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)


## 加法操作
print(t0 + t1)


## 矩阵乘法操作
print(t0.mm(t1))

2.3 XLA 张量与神经网络模块的结合

XLA 张量可以与神经网络模块结合使用,进行模型训练和推断。

l_in = torch.randn(10, device=device)
linear = torch.nn.Linear(10, 20).to(device)
l_out = linear(l_in)
print(l_out)

三、模型训练与多设备支持

3.1 在单个 XLA 设备上训练模型

import torch_xla.core.xla_model as xm


device = xm.xla_device()
model = torch.nn.Linear(10, 2).train().to(device)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


for data, target in train_loader:
    optimizer.zero_grad()
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)

3.2 在多个 XLA 设备上进行并行训练

import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    device = xm.xla_device()
    para_loader = pl.ParallelLoader(train_loader, [device])
    model = torch.nn.Linear(10, 2).train().to(device)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    
    for data, target in para_loader.per_device_loader(device):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)


if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())

3.3 通过多线程在多个 XLA 设备上运行

import torch_xla.distributed.data_parallel as dp


devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(torch.nn.Linear, device_ids=devices)


def train_loop_fn(model, loader, device, context):
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model.train()
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)


for epoch in range(1, num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)

四、XLA 张量特性与优化

4.1 XLA 张量的懒惰执行特性

XLA 张量采用懒惰执行模式,将操作记录在图中,直到需要结果时才执行。这允许 XLA 对图进行优化。

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)


## 操作会被记录在图中,直到需要结果时才执行
t2 = t0 + t1
t3 = t2.mm(t0)

4.2 使用 bFloat16 数据类型

在 TPU 上运行时,PyTorch XLA 可以使用 bFloat16 数据类型,这可以通过设置 XLA_USE_BF16 环境变量来启用。

import os


## 启用 bFloat16
os.environ['XLA_USE_BF16'] = '1'


t = torch.randn(2, 2, device=device)
print(t.dtype)  # 将显示 torch.bfloat16

4.3 内存布局优化

XLA 张量的内部数据表示对用户透明,它们始终看起来是连续的。这使 XLA 可以调整张量的内存布局以获得更好的性能。

五、XLA 张量的保存与加载

5.1 将 XLA 张量移入和移出 CPU

XLA 张量可以从 CPU 移到 XLA 设备,也可以从 XLA 设备移到 CPU。

## 将张量从 CPU 移到 XLA 设备
cpu_tensor = torch.randn(2, 2)
xla_tensor = cpu_tensor.to(device)


## 将张量从 XLA 设备移回 CPU
cpu_tensor = xla_tensor.cpu()

5.2 保存和加载 XLA 张量

在保存 XLA 张量之前,应将其移至 CPU。

## 保存 XLA 张量
xla_tensor = torch.randn(2, 2, device=device)
cpu_tensor = xla_tensor.cpu()
torch.save(cpu_tensor, 'tensor.pt')


## 加载 XLA 张量
loaded_cpu_tensor = torch.load('tensor.pt')
loaded_xla_tensor = loaded_cpu_tensor.to(device)

六、常见问题解答

Q1:如何在多个 XLA 设备上进行并行训练?

A1:可以通过 torch_xla.distributed.xla_multiprocessing.spawn 创建多个进程,每个进程分别在不同的 XLA 设备上运行模型。

Q2:XLA 张量的懒惰执行特性如何影响性能?

A2:XLA 张量的懒惰执行特性允许 XLA 对操作图进行优化,从而提高执行效率。在需要结果时,XLA 会自动同步执行。

Q3:如何启用 bFloat16 数据类型?

A3:可以通过设置环境变量 XLA_USE_BF16=1 启用 bFloat16 数据类型。这在 TPU 上运行时可以提高性能。

七、完整示例:在 XLA 设备上训练模型

以下是一个完整的示例,展示了如何在 XLA 设备上训练一个简单的模型:

import torch
import torch.nn as nn
import torch.optim as optim
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp


## 定义模型
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)


    def forward(self, x):
        return self.fc(x)


## 单设备训练函数
def train_single_device():
    device = xm.xla_device()
    model = SimpleModel().train().to(device)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)


    # 假设 train_loader 是一个 DataLoader
    for data, target in train_loader:
        optimizer.zero_grad()
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)


## 多设备并行训练函数
def _mp_fn(index):
    device = xm.xla_device()
    para_loader = pl.ParallelLoader(train_loader, [device])
    model = SimpleModel().train().to(device)
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)


    for data, target in para_loader.per_device_loader(device):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)


if __name__ == '__main__':
    # 单设备训练
    train_single_device()


    # 多设备并行训练
    xmp.spawn(_mp_fn, args=())

八、总结与展望

通过本文的详细介绍,我们掌握了 PyTorch XLA 的基本使用方法,包括如何在 XLA 设备上创建和操作张量、训练模型以及利用多设备并行处理加速训练。希望这些内容能帮助你在实际项目中高效地利用 XLA 设备。

关注编程狮(W3Cschool)平台,获取更多 PyTorch XLA 开发相关的教程和案例。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号