PyTorch 中的命名张量简介(实验性)

2025-06-18 17:14 更新

在数据科学和机器学习领域,张量操作是构建复杂模型的核心。PyTorch 作为一种广泛使用的深度学习框架,不断引入新特性以提升开发效率和代码可读性。命名张量(Named Tensor)作为 PyTorch 的一项实验性功能,旨在通过为张量维度赋予 meaningful 的名称,简化张量操作,减少因维度顺序错误导致的 bug,并提升代码的可维护性。本文将深入探讨 PyTorch 中命名张量的使用方法、优势以及在实际项目中的应用,帮助读者快速掌握这一强大工具。

一、命名张量基础知识

(一)创建命名张量

在 PyTorch 中,可以通过在创建张量时指定 names 参数来赋予张量维度名称。以下示例展示了如何创建具有命名维度的张量:

import torch


## 创建具有命名维度的张量
imgs = torch.randn(1, 2, 2, 3, names=('N', 'C', 'H', 'W'))
print(imgs.names)

(二)重命名与删除维度名称

命名张量的维度名称并非一成不变,可以通过以下方式对维度进行重命名或删除名称:

## 方法一:直接设置 .names 属性(原地操作)
imgs.names = ['batch', 'channel', 'width', 'height']
print(imgs.names)


## 方法二:指定新名称(创建新张量)
imgs = imgs.rename(channel='C', width='W', height='H')
print(imgs.names)


## 删除名称
imgs = imgs.rename(None)
print(imgs.names)

(三)混合命名与未命名张量

命名张量与未命名张量可以共存。若只想为部分维度指定名称,其余维度保持未命名状态,可以通过以下方式实现:

## 创建部分维度命名的张量
imgs = torch.randn(3, 1, 1, 2, names=('N', None, None, None))
print(imgs.names)

二、命名张量的操作与传播

(一)基本操作与名称传播

大多数张量操作(如 .abs())会保留维度名称,使得操作结果的可读性得以保持:

## 基本操作后名称传播
print(imgs.abs().names)

(二)索引与规约操作

可以通过维度名称进行索引和规约操作,使代码更具语义化:

## 按名称进行求和操作
output = imgs.sum('C')
print(output.names)


## 按名称选择特定维度数据
img0 = imgs.select('N', 0)
print(img0.names)

(三)名称推断机制

在张量操作过程中,PyTorch 会根据名称推断规则对输出张量的维度名称进行推断。这包括检查输入张量的名称是否匹配,并传播合适的名称至输出张量。

广播操作中的名称检查

命名张量在广播操作中会检查维度名称是否匹配,避免因维度对齐错误导致的意外结果:

## 广播操作中的名称检查
imgs = torch.randn(2, 2, 2, 2, names=('N', 'C', 'H', 'W'))
per_batch_scale = torch.rand(2, names=('N',))


## 尝试进行广播操作
try:
    imgs * per_batch_scale
except RuntimeError as e:
    print("错误信息:", e)

矩阵乘法中的名称传播

在矩阵乘法操作中,PyTorch 会根据输入张量的维度名称推断输出张量的维度名称:

## 矩阵乘法中的名称传播
markov_states = torch.randn(128, 5, names=('batch', 'D'))
transition_matrix = torch.randn(5, 5, names=('in', 'out'))
new_state = markov_states @ transition_matrix
print(new_state.names)

三、命名张量的新行为与应用

(一)按名称显式广播

命名张量支持通过 align_asalign_to 方法进行显式广播,使张量对齐操作更加直观:

## 按名称显式广播
imgs = imgs.refine_names('N', 'C', 'H', 'W')
per_batch_scale = per_batch_scale.refine_names('N')
named_result = imgs * per_batch_scale.align_as(imgs)

(二)按名称展平与展开维度

命名张量提供了 flattenunflatten 方法,支持按名称对维度进行展平和展开操作:

## 按名称展平维度
imgs = imgs.flatten(['C', 'H', 'W'], 'features')
print(imgs.names)


## 按名称展开维度
imgs = imgs.unflatten('features', (('C', 2), ('H', 2), ('W', 2)))
print(imgs.names)

四、命名张量在多头注意力模块中的应用

为了展示命名张量在实际项目中的优势,以下是一个使用命名张量实现多头注意力模块的示例:

import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, dim, dropout=0):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.attn_dropout = nn.Dropout(p=dropout)
        self.q_lin = nn.Linear(dim, dim)
        self.k_lin = nn.Linear(dim, dim)
        self.v_lin = nn.Linear(dim, dim)
        self.out_lin = nn.Linear(dim, dim)


    def forward(self, query, key=None, value=None, mask=None):
        query = query.refine_names(..., 'T', 'D')
        self_attn = key is None and value is None
        if self_attn:
            mask = mask.refine_names(..., 'T')
        else:
            mask = mask.refine_names(..., 'T', 'T_key')


        dim = query.size('D')
        n_heads = self.n_heads
        dim_per_head = dim // n_heads
        scale = math.sqrt(dim_per_head)


        def prepare_head(tensor):
            tensor = tensor.refine_names(..., 'T', 'D')
            return (tensor.unflatten('D', [('H', n_heads), ('D_head', dim_per_head)])
                          .align_to(..., 'H', 'T', 'D_head'))


        if self_attn:
            key = value = query
        elif value is None:
            key = key.refine_names(..., 'T', 'D')
            value = key


        k = prepare_head(self.k_lin(key)).rename(T='T_key')
        v = prepare_head(self.v_lin(value)).rename(T='T_key')
        q = prepare_head(self.q_lin(query))


        dot_prod = q.div_(scale).matmul(k.align_to(..., 'D_head', 'T_key'))
        dot_prod.refine_names(..., 'H', 'T', 'T_key')


        attn_mask = (mask == 0).align_as(dot_prod)
        dot_prod.masked_fill_(attn_mask, -float(1e20))
        attn_weights = self.attn_dropout(F.softmax(dot_prod, dim='T_key'))


        attentioned = (attn_weights.matmul(v).refine_names(..., 'H', 'T', 'D_head')
                           .align_to(..., 'T', 'H', 'D_head')
                           .flatten(['H', 'D_head'], 'D'))
        return self.out_lin(attentioned).refine_names(..., 'T', 'D')

五、总结与展望

命名张量作为 PyTorch 的一项创新特性,通过为张量维度赋予名称,极大地提升了代码的可读性和可维护性,减少了因维度顺序错误导致的 bug。本文详细介绍了命名张量的创建、操作、名称传播机制以及在多头注意力模块中的应用。尽管命名张量目前仍处于实验阶段,但其在提升开发效率和代码质量方面的潜力不容小觑。随着 PyTorch 的不断发展,命名张量有望成为深度学习开发中的标准工具之一。编程狮将持续关注 PyTorch 的最新动态,并为读者带来更多实用的深度学习技术教程。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号