PyTorch 命名为 Tensors 操作员范围
请首先阅读命名张量,以了解命名张量。
本文档是名称推断的参考,HTH1 是一个定义张量命名方式的过程:
- 使用名称提供其他自动运行时正确性检查
- 将名称从输入张量传播到输出张量
以下是命名张量及其关联的名称推断规则支持的所有操作的列表。
如果此处未列出操作,但对您的用例有帮助,请搜索问题是否已提交,否则请提交一个问题。
警告
命名的张量 API 是实验性的,随时可能更改。
Supported Operations
API | 名称推断规则 |
---|---|
Tensor.abs() , torch.abs() |
保留输入名称 |
Tensor.abs_() |
Keeps input names |
Tensor.acos() , torch.acos() |
Keeps input names |
Tensor.acos_() |
Keeps input names |
Tensor.add() , torch.add() |
统一输入的名称 |
Tensor.add_() |
Unifies names from inputs |
Tensor.addmm() , torch.addmm() |
缩小暗淡 |
Tensor.addmm_() |
Contracts away dims |
Tensor.addmv() , torch.addmv() |
Contracts away dims |
Tensor.addmv_() |
Contracts away dims |
Tensor.align_as() |
查看文件 |
Tensor.align_to() |
See documentation |
Tensor.all() ,torch.all() |
没有 |
Tensor.any() ,torch.any() |
None |
Tensor.asin() , torch.asin() |
Keeps input names |
Tensor.asin_() |
Keeps input names |
Tensor.atan() , torch.atan() |
Keeps input names |
Tensor.atan2() , torch.atan2() |
Unifies names from inputs |
Tensor.atan2_() |
Unifies names from inputs |
Tensor.atan_() |
Keeps input names |
Tensor.bernoulli() , torch.bernoulli() |
Keeps input names |
Tensor.bernoulli_() |
None |
Tensor.bfloat16() |
Keeps input names |
Tensor.bitwise_not() , torch.bitwise_not() |
Keeps input names |
Tensor.bitwise_not_() |
None |
Tensor.bmm() , torch.bmm() |
Contracts away dims |
Tensor.bool() |
Keeps input names |
Tensor.byte() |
Keeps input names |
torch.cat() |
Unifies names from inputs |
Tensor.cauchy_() |
None |
Tensor.ceil() , torch.ceil() |
Keeps input names |
Tensor.ceil_() |
None |
Tensor.char() |
Keeps input names |
Tensor.chunk() , torch.chunk() |
Keeps input names |
Tensor.clamp() , torch.clamp() |
Keeps input names |
Tensor.clamp_() |
None |
Tensor.copy_() |
输出功能和就地变体 |
Tensor.cos() , torch.cos() |
Keeps input names |
Tensor.cos_() |
None |
Tensor.cosh() , torch.cosh() |
Keeps input names |
Tensor.cosh_() |
None |
Tensor.cpu() |
Keeps input names |
Tensor.cuda() |
Keeps input names |
Tensor.cumprod() , torch.cumprod() |
Keeps input names |
Tensor.cumsum() , torch.cumsum() |
Keeps input names |
Tensor.data_ptr() |
None |
Tensor.detach() ,torch.detach() |
Keeps input names |
Tensor.detach_() |
None |
Tensor.device , torch.device() |
None |
Tensor.digamma() , torch.digamma() |
Keeps input names |
Tensor.digamma_() |
None |
Tensor.dim() |
None |
Tensor.div() , torch.div() |
Unifies names from inputs |
Tensor.div_() |
Unifies names from inputs |
Tensor.dot() , torch.dot() |
None |
Tensor.double() |
Keeps input names |
Tensor.element_size() |
None |
torch.empty() |
工厂功能 |
torch.empty_like() |
Factory functions |
Tensor.eq() , torch.eq() |
Unifies names from inputs |
Tensor.erf() , torch.erf() |
Keeps input names |
Tensor.erf_() |
None |
Tensor.erfc() , torch.erfc() |
Keeps input names |
Tensor.erfc_() |
None |
Tensor.erfinv() , torch.erfinv() |
Keeps input names |
Tensor.erfinv_() |
None |
Tensor.exp() , torch.exp() |
Keeps input names |
Tensor.exp_() |
None |
Tensor.expand() |
Keeps input names |
Tensor.expm1() , torch.expm1() |
Keeps input names |
Tensor.expm1_() |
None |
Tensor.exponential_() |
None |
Tensor.fill_() |
None |
Tensor.flatten() , torch.flatten() |
See documentation |
Tensor.float() |
Keeps input names |
Tensor.floor() , torch.floor() |
Keeps input names |
Tensor.floor_() |
None |
Tensor.frac() , torch.frac() |
Keeps input names |
Tensor.frac_() |
None |
Tensor.ge() , torch.ge() |
Unifies names from inputs |
Tensor.get_device() ,torch.get_device() |
None |
Tensor.grad |
None |
Tensor.gt() , torch.gt() |
Unifies names from inputs |
Tensor.half() |
Keeps input names |
Tensor.has_names() |
See documentation |
Tensor.index_fill() ,torch.index_fill() |
Keeps input names |
Tensor.index_fill_() |
None |
Tensor.int() |
Keeps input names |
Tensor.is_contiguous() |
None |
Tensor.is_cuda |
None |
Tensor.is_floating_point() , torch.is_floating_point() |
None |
Tensor.is_leaf |
None |
Tensor.is_pinned() |
None |
Tensor.is_shared() |
None |
Tensor.is_signed() ,torch.is_signed() |
None |
Tensor.is_sparse |
None |
torch.is_tensor() |
None |
Tensor.item() |
None |
Tensor.kthvalue() , torch.kthvalue() |
移除尺寸 |
Tensor.le() , torch.le() |
Unifies names from inputs |
Tensor.log() , torch.log() |
Keeps input names |
Tensor.log10() , torch.log10() |
Keeps input names |
Tensor.log10_() |
None |
Tensor.log1p() , torch.log1p() |
Keeps input names |
Tensor.log1p_() |
None |
Tensor.log2() , torch.log2() |
Keeps input names |
Tensor.log2_() |
None |
Tensor.log_() |
None |
Tensor.log_normal_() |
None |
Tensor.logical_not() , torch.logical_not() |
Keeps input names |
Tensor.logical_not_() |
None |
Tensor.logsumexp() , torch.logsumexp() |
Removes dimensions |
Tensor.long() |
Keeps input names |
Tensor.lt() , torch.lt() |
Unifies names from inputs |
torch.manual_seed() |
None |
Tensor.masked_fill() ,torch.masked_fill() |
Keeps input names |
Tensor.masked_fill_() |
None |
Tensor.masked_select() , torch.masked_select() |
将遮罩对齐到输入,然后 unified_names_from_input_tensors |
Tensor.matmul() , torch.matmul() |
Contracts away dims |
Tensor.mean() , torch.mean() |
Removes dimensions |
Tensor.median() , torch.median() |
Removes dimensions |
Tensor.mm() , torch.mm() |
Contracts away dims |
Tensor.mode() , torch.mode() |
Removes dimensions |
Tensor.mul() , torch.mul() |
Unifies names from inputs |
Tensor.mul_() |
Unifies names from inputs |
Tensor.mv() , torch.mv() |
Contracts away dims |
Tensor.names |
See documentation |
Tensor.narrow() , torch.narrow() |
Keeps input names |
Tensor.ndim |
None |
Tensor.ndimension() |
None |
Tensor.ne() , torch.ne() |
Unifies names from inputs |
Tensor.neg() , torch.neg() |
Keeps input names |
Tensor.neg_() |
None |
torch.normal() |
Keeps input names |
Tensor.normal_() |
None |
Tensor.numel() , torch.numel() |
None |
torch.ones() |
Factory functions |
Tensor.pow() , torch.pow() |
Unifies names from inputs |
Tensor.pow_() |
None |
Tensor.prod() , torch.prod() |
Removes dimensions |
torch.rand() |
Factory functions |
torch.rand() |
Factory functions |
torch.randn() |
Factory functions |
torch.randn() |
Factory functions |
Tensor.random_() |
None |
Tensor.reciprocal() , torch.reciprocal() |
Keeps input names |
Tensor.reciprocal_() |
None |
Tensor.refine_names() |
See documentation |
Tensor.register_hook() |
None |
Tensor.rename() |
See documentation |
Tensor.rename_() |
See documentation |
Tensor.requires_grad |
None |
Tensor.requires_grad_() |
None |
Tensor.resize_() |
只允许不改变形状的调整大小 |
Tensor.resize_as_() |
Only allow resizes that do not change shape |
Tensor.round() , torch.round() |
Keeps input names |
Tensor.round_() |
None |
Tensor.rsqrt() , torch.rsqrt() |
Keeps input names |
Tensor.rsqrt_() |
None |
Tensor.select() ,torch.select() |
Removes dimensions |
Tensor.short() |
Keeps input names |
Tensor.sigmoid() , torch.sigmoid() |
Keeps input names |
Tensor.sigmoid_() |
None |
Tensor.sign() , torch.sign() |
Keeps input names |
Tensor.sign_() |
None |
Tensor.sin() , torch.sin() |
Keeps input names |
Tensor.sin_() |
None |
Tensor.sinh() , torch.sinh() |
Keeps input names |
Tensor.sinh_() |
None |
Tensor.size() |
None |
Tensor.split() , torch.split() |
Keeps input names |
Tensor.sqrt() , torch.sqrt() |
Keeps input names |
Tensor.sqrt_() |
None |
Tensor.squeeze() , torch.squeeze() |
Removes dimensions |
Tensor.std() , torch.std() |
Removes dimensions |
torch.std_mean() |
Removes dimensions |
Tensor.stride() |
None |
Tensor.sub() ,torch.sub() |
Unifies names from inputs |
Tensor.sub_() |
Unifies names from inputs |
Tensor.sum() , torch.sum() |
Removes dimensions |
Tensor.tan() , torch.tan() |
Keeps input names |
Tensor.tan_() |
None |
Tensor.tanh() , torch.tanh() |
Keeps input names |
Tensor.tanh_() |
None |
torch.tensor() |
Factory functions |
Tensor.to() |
Keeps input names |
Tensor.topk() , torch.topk() |
Removes dimensions |
Tensor.transpose() , torch.transpose() |
排列尺寸 |
Tensor.trunc() , torch.trunc() |
Keeps input names |
Tensor.trunc_() |
None |
Tensor.type() |
None |
Tensor.type_as() |
Keeps input names |
Tensor.unbind() , torch.unbind() |
Removes dimensions |
Tensor.unflatten() |
See documentation |
Tensor.uniform_() |
None |
Tensor.var() , torch.var() |
Removes dimensions |
torch.var_mean() | Removes dimensions |
Tensor.zero_() |
None |
torch.zeros() |
Factory functions |
保留输入名称
所有逐点一元函数以及其他一些一元函数都遵循此规则。
- 检查姓名:无
- 传播名称:输入张量的名称会传播到输出。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
移除尺寸
所有缩小操作,例如 sum()
,都会通过缩小所需尺寸来删除尺寸。 select()
和 squeeze()
等其他操作会删除尺寸。
只要有人可以将整数维度索引传递给运算符,就可以传递维度名称。 包含维索引列表的函数也可以包含维名称列表。
- 检查名称:如果
dim
或dims
作为名称列表传入,请检查self
中是否存在这些名称。 - 传播名称:如果在输出张量中不存在
dim
或dims
指定的输入张量的尺寸,则这些尺寸的相应名称不会出现在output.names
中。
>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')
## Reduction ops with keepdim=True don't actually remove dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')
统一输入中的名称
所有二进制算术运算都遵循此规则。 广播操作仍然从右侧进行位置广播,以保持与未命名张量的兼容性。 要通过名称执行显式广播,请使用 Tensor.align_as()
。
- 检查名称:所有名称都必须从右侧位置匹配。 即,在
tensor + other
中,对于(-min(tensor.dim(), other.dim()) + 1, -1]
中的所有i
,match(tensor.names[i], other.names[i])
必须为 true。 - 检查名称:此外,所有命名的尺寸必须从右对齐。 在匹配期间,如果我们将命名尺寸
A
与未命名尺寸None
匹配,则A
不得出现在具有未命名尺寸的张量中。 - 传播名称:从两个张量的右边开始统一名称对,以产生输出名称。
例如,
## tensor: Tensor[ N, None]
## other: Tensor[None, C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')
检查姓名:
match(tensor.names[-1], other.names[-1])
是True
match(tensor.names[-2], tensor.names[-2])
是True
- 由于我们将
tensor
中的None
与'C'
匹配,因此请确保tensor
中不存在'C'
。 - 检查以确保
other
中不存在'N'
(不存在)。
最后,使用[unify('N', None), unify(None, 'C')] = ['N', 'C']
计算输出名称
更多示例:
## Dimensions don't match from the right:
## tensor: Tensor[N, C]
## other: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Error when attempting to broadcast dims ['N', 'C'] and dims
['N']: dim 'C' and dim 'N' are at the same position from the right but do
not match.
## Dimensions aren't aligned when matching tensor.names[-1] and other.names[-1]:
## tensor: Tensor[N, None]
## other: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: Misaligned dims when attempting to broadcast dims ['N'] and
dims ['N', None]: dim 'N' appears in a different position from the right
across both lists.
注意
在最后两个示例中,可以通过名称对齐张量,然后执行加法。 使用 Tensor.align_as()
按名称对齐张量,或使用 Tensor.align_to()
将张量对齐到自定义尺寸顺序。
排列尺寸
某些操作(例如 Tensor.t()
)会置换尺寸顺序。 维度名称附加到各个维度,因此也可以排列。
如果操作员输入位置索引dim
,它也可以采用尺寸名称作为dim
。
- 检查名称:如果将
dim
作为名称传递,请检查其是否在张量中存在。 - 传播名称:以与要排列的维相同的方式排列维名称。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')
收缩消失
矩阵乘法函数遵循此方法的某些变体。 让我们先通过 torch.mm()
,然后概括一下批矩阵乘法的规则。
对于torch.mm(tensor, other)
:
- Check names: None
- 传播名称:结果名称为
(tensor.names[-2], other.names[-1])
。
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')
本质上,矩阵乘法在二维上执行点积运算,使它们折叠。 当两个张量矩阵相乘时,收缩尺寸消失,并且不出现在输出张量中。
torch.mv()
, torch.dot()
的工作方式类似:名称推断不会检查输入名称,并且会删除点积所涉及的尺寸:
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)
现在,让我们看一下torch.matmul(tensor, other)
。 假设tensor.dim() >= 2
和other.dim() >= 2
。
- 检查名称:检查输入的批次尺寸是否对齐并可以广播。 请参见统一输入的名称,以了解对齐输入的含义。
- 传播名称:结果名称是通过统一批次尺寸并删除合同规定的尺寸获得的:
unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])
。
例子:
## Batch matrix multiply of matrices Tensor['C', 'D'] and Tensor['E', 'F'].
## 'A', 'B' are batch dimensions.
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')
最后,还有许多功能的融合add
版本。 即 addmm()
和 addmv()
。 这些被视为构成 mm()
的名称推断和 add()
的命名推断。
工厂功能
现在,工厂函数采用新的names
参数,该参数将名称与每个维度相关联。
>>> torch.zeros(2, 3, names=('N', 'C'))
tensor([[0., 0., 0.],
[0., 0., 0.]], names=('N', 'C'))
输出功能和就地变型
指定为out=
张量的张量具有以下行为:
- 如果没有命名维,则将从操作中计算出的名称传播到其中。
- 如果它具有任何命名维,则从该操作计算出的名称必须与现有名称完全相同。 否则,操作错误。
所有就地方法都会将输入修改为具有与根据名称推断计算出的名称相同的名称。 例如,
>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)
>>> x += y
>>> x.names
('N', 'C')
更多建议: