PyTorch 命名为 Tensors 操作员范围

2025-07-02 17:58 更新

PyTorch 命名张量操作详解

一、命名张量操作概述

命名张量操作是 PyTorch 提供的一种增强张量操作可读性和安全性的机制。通过为张量的维度赋予名称,开发者可以在操作张量时更加直观地指定维度,避免因维度位置错误导致的程序错误。以下是命名张量支持的主要操作及其名称推断规则。

二、保留输入名称的操作

所有逐点一元函数以及其他一些一元函数都会保留输入张量的名称。

示例代码:

import torch


## 创建一个命名张量
x = torch.randn(3, 3, names=('N', 'C'))


## 应用 abs() 操作,保留输入名称
x_abs = x.abs()
print(x_abs.names)  # 输出: ('N', 'C')

三、移除尺寸的操作

缩小操作(如 sum())会移除指定的尺寸。其他如 select()squeeze() 等操作也会移除尺寸。

示例代码:

## 创建一个命名张量
x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))


## 应用 squeeze() 操作,移除尺寸 'N'
x_squeezed = x.squeeze('N')
print(x_squeezed.names)  # 输出: ('C', 'H', 'W')


## 应用 sum() 操作,移除尺寸 'N' 和 'C'
x_summed = x.sum(['N', 'C'])
print(x_summed.names)  # 输出: ('H', 'W')

注意事项:

  • 如果 dimdims 作为名称列表传入,需检查这些名称是否存在于输入张量中。
  • 缩小操作的输出张量中不会包含被移除尺寸的名称。

四、统一输入名称的操作

所有二进制算术运算(如加法、乘法)都会统一输入张量的名称。

示例代码:

## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))


## 应用加法操作,统一输入名称
x_plus_y = x + y
print(x_plus_y.names)  # 输出: ('N', 'C')


## 创建另一个命名张量,维度名称不同
z = torch.randn(3, 3, names=('A', 'B'))


## 尝试加法操作,由于名称不匹配,会报错
## x_plus_z = x + z  # 报错

注意事项:

  • 所有名称必须从右侧位置匹配。即,在 tensor + other 中,对于 (-min(tensor.dim(), other.dim()) + 1, -1] 中的所有 imatch(tensor.names[i], other.names[i]) 必须为 True
  • 如果将命名维度与未命名维度匹配,则确保未命名维度中不存在该命名维度的名称。

五、排列尺寸的操作

某些操作(如 transpose())会改变维度顺序。命名张量会根据操作排列维度名称。

示例代码:

## 创建一个命名张量
x = torch.randn(3, 3, names=('N', 'C'))


## 应用 transpose() 操作,排列尺寸
x_transposed = x.transpose('N', 'C')
print(x_transposed.names)  # 输出: ('C', 'N')

六、收缩消失的操作

矩阵乘法操作(如 mm())会收缩指定的维度,使其消失在输出张量中。

示例代码:

## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'D'))
y = torch.randn(3, 3, names=('D', 'M'))


## 应用 mm() 操作,收缩 'D' 维度
x_mm_y = torch.mm(x, y)
print(x_mm_y.names)  # 输出: ('N', 'M')

七、工厂函数

工厂函数(如 torch.zeros()torch.randn())可以通过 names 参数创建命名张量。

示例代码:

## 使用工厂函数创建命名张量
named_zeros = torch.zeros(2, 3, names=('N', 'C'))
print(named_zeros.names)  # 输出: ('N', 'C')


named_randn = torch.randn(2, 3, names=('N', 'C'))
print(named_randn.names)  # 输出: ('N', 'C')

八、输出函数和就地变体

指定为 out= 参数的张量会根据操作传播名称。就地操作会修改输入张量的名称。

示例代码:

## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))


## 创建一个未命名张量
z = torch.randn(3, 3)


## 应用加法操作,指定输出张量
torch.add(x, y, out=z)
print(z.names)  # 输出: ('N', 'C')


## 应用就地加法操作
x += y
print(x.names)  # 输出: ('N', 'C')

九、总结

通过本教程,我们详细介绍了 PyTorch 中命名张量的主要操作及其名称推断规则。这些规则确保了张量操作的正确性,并提高了代码的可读性和可维护性。掌握这些操作,可以帮助您更高效地开发深度学习模型,减少因维度错误导致的程序问题。

以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号