PyTorch 命名为 Tensors 操作员范围
2025-07-02 17:58 更新
PyTorch 命名张量操作详解
一、命名张量操作概述
命名张量操作是 PyTorch 提供的一种增强张量操作可读性和安全性的机制。通过为张量的维度赋予名称,开发者可以在操作张量时更加直观地指定维度,避免因维度位置错误导致的程序错误。以下是命名张量支持的主要操作及其名称推断规则。
二、保留输入名称的操作
所有逐点一元函数以及其他一些一元函数都会保留输入张量的名称。
示例代码:
import torch
## 创建一个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
## 应用 abs() 操作,保留输入名称
x_abs = x.abs()
print(x_abs.names) # 输出: ('N', 'C')
三、移除尺寸的操作
缩小操作(如 sum()
)会移除指定的尺寸。其他如 select()
和 squeeze()
等操作也会移除尺寸。
示例代码:
## 创建一个命名张量
x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
## 应用 squeeze() 操作,移除尺寸 'N'
x_squeezed = x.squeeze('N')
print(x_squeezed.names) # 输出: ('C', 'H', 'W')
## 应用 sum() 操作,移除尺寸 'N' 和 'C'
x_summed = x.sum(['N', 'C'])
print(x_summed.names) # 输出: ('H', 'W')
注意事项:
- 如果
dim
或dims
作为名称列表传入,需检查这些名称是否存在于输入张量中。 - 缩小操作的输出张量中不会包含被移除尺寸的名称。
四、统一输入名称的操作
所有二进制算术运算(如加法、乘法)都会统一输入张量的名称。
示例代码:
## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))
## 应用加法操作,统一输入名称
x_plus_y = x + y
print(x_plus_y.names) # 输出: ('N', 'C')
## 创建另一个命名张量,维度名称不同
z = torch.randn(3, 3, names=('A', 'B'))
## 尝试加法操作,由于名称不匹配,会报错
## x_plus_z = x + z # 报错
注意事项:
- 所有名称必须从右侧位置匹配。即,在
tensor + other
中,对于(-min(tensor.dim(), other.dim()) + 1, -1]
中的所有i
,match(tensor.names[i], other.names[i])
必须为True
。 - 如果将命名维度与未命名维度匹配,则确保未命名维度中不存在该命名维度的名称。
五、排列尺寸的操作
某些操作(如 transpose()
)会改变维度顺序。命名张量会根据操作排列维度名称。
示例代码:
## 创建一个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
## 应用 transpose() 操作,排列尺寸
x_transposed = x.transpose('N', 'C')
print(x_transposed.names) # 输出: ('C', 'N')
六、收缩消失的操作
矩阵乘法操作(如 mm()
)会收缩指定的维度,使其消失在输出张量中。
示例代码:
## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'D'))
y = torch.randn(3, 3, names=('D', 'M'))
## 应用 mm() 操作,收缩 'D' 维度
x_mm_y = torch.mm(x, y)
print(x_mm_y.names) # 输出: ('N', 'M')
七、工厂函数
工厂函数(如 torch.zeros()
、torch.randn()
)可以通过 names
参数创建命名张量。
示例代码:
## 使用工厂函数创建命名张量
named_zeros = torch.zeros(2, 3, names=('N', 'C'))
print(named_zeros.names) # 输出: ('N', 'C')
named_randn = torch.randn(2, 3, names=('N', 'C'))
print(named_randn.names) # 输出: ('N', 'C')
八、输出函数和就地变体
指定为 out=
参数的张量会根据操作传播名称。就地操作会修改输入张量的名称。
示例代码:
## 创建两个命名张量
x = torch.randn(3, 3, names=('N', 'C'))
y = torch.randn(3, 3, names=('N', 'C'))
## 创建一个未命名张量
z = torch.randn(3, 3)
## 应用加法操作,指定输出张量
torch.add(x, y, out=z)
print(z.names) # 输出: ('N', 'C')
## 应用就地加法操作
x += y
print(x.names) # 输出: ('N', 'C')
九、总结
通过本教程,我们详细介绍了 PyTorch 中命名张量的主要操作及其名称推断规则。这些规则确保了张量操作的正确性,并提高了代码的可读性和可维护性。掌握这些操作,可以帮助您更高效地开发深度学习模型,减少因维度错误导致的程序问题。
以上内容是否对您有帮助:
更多建议: