PyTorch 中的命名张量简介(实验性)
在数据科学和机器学习领域,张量操作是构建复杂模型的核心。PyTorch 作为一种广泛使用的深度学习框架,不断引入新特性以提升开发效率和代码可读性。命名张量(Named Tensor)作为 PyTorch 的一项实验性功能,旨在通过为张量维度赋予 meaningful 的名称,简化张量操作,减少因维度顺序错误导致的 bug,并提升代码的可维护性。本文将深入探讨 PyTorch 中命名张量的使用方法、优势以及在实际项目中的应用,帮助读者快速掌握这一强大工具。
一、命名张量基础知识
(一)创建命名张量
在 PyTorch 中,可以通过在创建张量时指定 names
参数来赋予张量维度名称。以下示例展示了如何创建具有命名维度的张量:
import torch
## 创建具有命名维度的张量
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
print(imgs.names)
(二)重命名与删除维度名称
命名张量的维度名称并非一成不变,可以通过以下方式对维度进行重命名或删除名称:
## 方法一:直接设置 .names 属性(原地操作)
imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)
## 方法二:指定新名称(创建新张量)
imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)
## 删除名称
imgs = imgs.rename(None)
print(imgs.names)
(三)混合命名与未命名张量
命名张量与未命名张量可以共存。若只想为部分维度指定名称,其余维度保持未命名状态,可以通过以下方式实现:
## 创建部分维度命名的张量
imgs = torch.randn(3, 1, 1, 2, names=('N', None, None, None))
print(imgs.names)
二、命名张量的操作与传播
(一)基本操作与名称传播
大多数张量操作(如 .abs()
)会保留维度名称,使得操作结果的可读性得以保持:
## 基本操作后名称传播
print(imgs.abs().names)
(二)索引与规约操作
可以通过维度名称进行索引和规约操作,使代码更具语义化:
## 按名称进行求和操作
output = imgs.sum('C')
print(output.names)
## 按名称选择特定维度数据
img0 = imgs.select('N', 0)
print(img0.names)
(三)名称推断机制
在张量操作过程中,PyTorch 会根据名称推断规则对输出张量的维度名称进行推断。这包括检查输入张量的名称是否匹配,并传播合适的名称至输出张量。
广播操作中的名称检查
命名张量在广播操作中会检查维度名称是否匹配,避免因维度对齐错误导致的意外结果:
## 广播操作中的名称检查
imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))
## 尝试进行广播操作
try:
imgs * per_batch_scale
except RuntimeError as e:
print("错误信息:", e)
矩阵乘法中的名称传播
在矩阵乘法操作中,PyTorch 会根据输入张量的维度名称推断输出张量的维度名称:
## 矩阵乘法中的名称传播
markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))
new_state = markov_states @ transition_matrix
print(new_state.names)
三、命名张量的新行为与应用
(一)按名称显式广播
命名张量支持通过 align_as
或 align_to
方法进行显式广播,使张量对齐操作更加直观:
## 按名称显式广播
imgs = imgs.refine_names('N', 'C', 'H', 'W')
per_batch_scale = per_batch_scale.refine_names('N')
named_result = imgs * per_batch_scale.align_as(imgs)
(二)按名称展平与展开维度
命名张量提供了 flatten
和 unflatten
方法,支持按名称对维度进行展平和展开操作:
## 按名称展平维度
imgs = imgs.flatten(['C', 'H', 'W'], 'features')
print(imgs.names)
## 按名称展开维度
imgs = imgs.unflatten('features', (('C', 2), ('H', 2), ('W', 2)))
print(imgs.names)
四、命名张量在多头注意力模块中的应用
为了展示命名张量在实际项目中的优势,以下是一个使用命名张量实现多头注意力模块的示例:
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, n_heads, dim, dropout=0):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.dim = dim
self.attn_dropout = nn.Dropout(p=dropout)
self.q_lin = nn.Linear(dim, dim)
self.k_lin = nn.Linear(dim, dim)
self.v_lin = nn.Linear(dim, dim)
self.out_lin = nn.Linear(dim, dim)
def forward(self, query, key=None, value=None, mask=None):
query = query.refine_names(..., 'T', 'D')
self_attn = key is None and value is None
if self_attn:
mask = mask.refine_names(..., 'T')
else:
mask = mask.refine_names(..., 'T', 'T_key')
dim = query.size('D')
n_heads = self.n_heads
dim_per_head = dim // n_heads
scale = math.sqrt(dim_per_head)
def prepare_head(tensor):
tensor = tensor.refine_names(..., 'T', 'D')
return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
.align_to(..., 'H', 'T', 'D_head'))
if self_attn:
key = value = query
elif value is None:
key = key.refine_names(..., 'T', 'D')
value = key
k = prepare_head(self.k_lin(key)).rename(T='T_key')
v = prepare_head(self.v_lin(value)).rename(T='T_key')
q = prepare_head(self.q_lin(query))
dot_prod = q.div_(scale).matmul(k.align_to(..., 'D_head', 'T_key'))
dot_prod.refine_names(..., 'H', 'T', 'T_key')
attn_mask = (mask == 0).align_as(dot_prod)
dot_prod.masked_fill_(attn_mask, -float(1e20))
attn_weights = self.attn_dropout(F.softmax(dot_prod, dim='T_key'))
attentioned = (attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
.align_to(..., 'T', 'H', 'D_head')
.flatten(['H', 'D_head'], 'D'))
return self.out_lin(attentioned).refine_names(..., 'T', 'D')
五、总结与展望
命名张量作为 PyTorch 的一项创新特性,通过为张量维度赋予名称,极大地提升了代码的可读性和可维护性,减少了因维度顺序错误导致的 bug。本文详细介绍了命名张量的创建、操作、名称传播机制以及在多头注意力模块中的应用。尽管命名张量目前仍处于实验阶段,但其在提升开发效率和代码质量方面的潜力不容小觑。随着 PyTorch 的不断发展,命名张量有望成为深度学习开发中的标准工具之一。编程狮将持续关注 PyTorch 的最新动态,并为读者带来更多实用的深度学习技术教程。
更多建议: