pytorch中张量的维度操作
引言
对于我本人来说,在阅读深度学习项目的源码时,分析张量的维度是一件让我十分头疼的事情。在我自己写程序时也会遇到因为张量维度不匹配而报错的问题,这时的我总会一头雾水,对于如何debug无从下手。更令人悲痛的是,当你把维度不匹配的报错交给AI企图让它来解决问题时,AI往往会一通维度操作猛如虎,最后运行一看还是跟之前相同的报错。
之前我零零散散地查阅过一些关于操作维度的函数的用法,但没系统地整理过,所以今天在这里整理一下。当然,如果真的遇到了维度不匹配的报错,那还是得从头一点点排查。只不过,熟记了这些张量操作的函数产生的效果后,也许debug会比之前顺利许多。
在计算机视觉领域,图像张量常设为四维,即
1 | [B, C, H, W] |
B即batch_size,是一次批量处理图像的数量C即channels,为图像通道数。RGB图像的通道数为3,经过某些操作(如卷积)通道数会发生改变H即height,为一张图像的高W即width,为一张图像的宽
维度操作
调整形状
tensor.view()返回一个具有相同数据但不同形状的新张量,要求内存必须是连续的。
tensor.reshape()效果类似于
view(),但不要求数据连续,因此更推荐使用。
1 | x = torch.randn((2, 3, 5, 7)) |
增减维度
tensor.unsqueeze(dim)在
dim位置增加一个维度,大小为1。(squeeze的意思为“压缩”,unsqueeze顾名思义就是伸展一个维度)1
2
3
4
5print(x.shape) # torch.Size([2, 3, 5, 7])
x_prime = x.unsqueeze(0)
print(x_prime.shape) # torch.Size([1, 2, 3, 5, 7])
x_prime = x_prime.unsqueeze(2)
print(x_prime.shape) # torch.Size([1, 2, 1, 3, 5, 7])tensor.squeeze(dim=None)当
dim参数缺省时,默认移除所有大小为1的维度。指定维度时,若该维度大小为1则将其移除,否则无影响。1
2
3
4
5
6
7
8print(x_prime.shape) # torch.Size([1, 2, 1, 3, 5, 7])
y_prime = x_prime.squeeze()
z1_prime = x_prime.squeeze(0)
z2_prime = x_prime.squeeze(1)
z3_prime = x_prime.squeeze(2)
print(y_prime.shape) # torch.Size([2, 3, 5, 7])
print(z1_prime.shape, z2_prime.shape, z3_prime.shape)
# torch.Size([2, 1, 3, 5, 7]) torch.Size([1, 2, 1, 3, 5, 7]) torch.Size([1, 2, 3, 5, 7])
交换维度
tensor.permute(dims)交换多个维度并返回新张量(不修改原始数据)。
tensor.transpose(dim0, dim1)交换传入的两个特定的维度。
1 | print(y.shape) # torch.Size([6, 5, 7]) |
维度展开与合并
tensor.flatten(start_dim=0. end_dim=-1)将
start_dim和end_dim之间的维度合并为一个。参数缺省时分别默认为第一个和最后一个维度。1
2
3
4
5print(z.shape) # torch.Size([6, 5, 7])
z_flatten = z.flatten()
print(z_flatten.shape) # torch.Size([210])
z_flatten = z.flatten(0, 1)
print(z_flatten.shape) # torch.Size([30, 7])tensor.unflatten(dim, sizes)将
dim维度拆分成sizes指定的多个维度。1
2
3print(z_flatten.shape) # torch.Size([30, 7])
z_unflatten = z_flatten.unflatten(0, (3, 10))
print(z_unflatten.shape) # torch.Size([3, 10, 7])
张量扩展
tensor.repeat(*sizes)沿指定维度重复张量,传入函数的参数数量要和张量的维数一样。作用于每个维度,独立复制元素。
1
2
3print(x.shape) # torch.Size([2, 3, 5, 7])
x_repeat = x.repeat(4, 2, 1, 1)
print(x_repeat.shape) # torch.Size([8, 6, 5, 7])tensor.tiles(sizes)功能类似
repeat(),作用于整个张量,复制整个块。1
2
3print(x.shape) # torch.Size([2, 3, 5, 7])
x_tile = x.tile(5, 3, 4, 1)
print(x_tile.shape) # torch.Size([10, 9, 20, 7])tensor.expand(*sizes)主要用于无额外内存开销地扩展张量,可以让一个张量沿着指定的维度进行广播,但不会实际复制数据。通常用于在计算时匹配张量的形状,而不增加额外的内存消耗,只是改变张量的视图。传入的
*sizes相当于想要的形状。注意:
expand()只能扩展大小为1的维度,否则会报错。1
2
3x = torch.randn(1, 2, 3)
x_expanded = x.expand(4, 2, 3)
print(x_expanded.shape) # torch.Size([4, 2, 3])
| 函数 | 是否额外占用内存 | 作用 |
|---|---|---|
expand() |
否 | 仅改变张量视图,广播显示 |
repeat() |
是 | 按元素复制张量 |
tile() |
是 | 按整体块复制张量 |
张量拼接
torch.cat(tensors, dim=0)沿
dim维度拼接多个张量,这些张量在除了dim以外的其他维度形状必须一致。所得的新张量维数和它们相同。1
2
3
4
5
6
7x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
z1 = torch.cat((x, y), dim=0)
z2 = torch.cat((x, y), dim=1)
z3 = torch.cat((x, y), dim=2)
print(z1.shape, z2.shape, z3.shape)
# torch.Size([4, 3, 5]) torch.Size([2, 6, 5]) torch.Size([2, 3, 10])torch.stack(tensors, dim=0)在指定
dim维度上新增一个维度,然后在该维度上堆叠张量。所有被堆叠的张量的所有维度形状都必须相同。可以理解为
stack = unsqueeze + cat,即对所有张量先在dim处新增一个维度,再在该维度上进行cat操作。1
2
3
4
5
6
7print(x.shape) # torch.Size([2, 3, 5])
print(y.shape) # torch.Size([2, 3, 5])
z4 = torch.stack((x, y), dim=0)
z5 = torch.stack((x, y), dim=1)
z6 = torch.stack((x, y), dim=2)
print(z4.shape, z5.shape, z6.shape)
# torch.Size([2, 2, 3, 5]) torch.Size([2, 2, 3, 5]) torch.Size([2, 3, 2, 5])
总结:
| 操作 | 方法 | 作用 |
|---|---|---|
| 调整形状 | view() / reshape() |
改变形状 |
| 增加维度 | unsqueeze() |
增加维度 |
| 减少维度 | squeeze() |
移除大小为 1 的维度 |
| 交换维度 | permute() / transpose() |
交换多个维度 / 交换两个维度 |
| 展开/合并 | flatten() / unflatten() |
合并维度 / 拆分维度 |
| 重复数据 | repeat() / tile() / expand() |
沿维度复制数据 |
| 拼接数据 | cat() / stack() |
拼接多个张量 |
张量运算
爱因斯坦求和约定
einsum
torch.einsum提供了一种基于爱因斯坦求和约定的简洁表达方式,能够高效执行各种张量操作,包括矩阵乘法、内积、外积、转置等。
语法为:
1 | torch.einsum(equation, *operands) |
其中: - equation:字符串,定义操作规则 -
operands:一个或者多个张量
爱因斯坦求和约定的内容为: - 若索引重复,表示在该维度上求和 - 若索引不重复,表示保留该维度 - 若索引在输出部分缺失,表示该维度被求和消去
常见应用: 1. 矩阵乘法
1 | A = torch.randn(2, 3) |
1 | a = torch.tensor([1, 2, 3]) |
1 | a = torch.tensor([1, 2, 3]) |
1 | A = torch.randn(5, 2, 3) # 5个 (2x3) 矩阵 |
1 | A = torch.randn(3, 3) |
1 | A = torch.randn(2, 3) |
张量乘法 matmul
matmul的行为取决于张量的维度。低维张量(如标量、向量和矩阵)的运算较为简单:
A维度 |
B维度 |
torch.matmul(A, B)维度 |
|---|---|---|
| 0(标量) | 0 | 0 |
| 1(向量) | 1 | 0(点积) |
| 2(矩阵) | 1 | 1 |
| 2 | 2 | 2 |
对于高维张量(三维及以上),torch.matmul(A, B)采用批量矩阵乘法。即从右往左取A和B的最后两个维度进行标准矩阵乘法,前面的批量维度自动对齐(需要保证可广播)
1 | a = torch.randn(1, 4, 2, 3) |
同时,pytorch还提供了torch.bmm()用于严格的3D批量矩阵乘法,即仅适用于(b, m, k) @ (b, k, n)。
- Title: pytorch中张量的维度操作
- Author: Jachin Zhang
- Created at : 2025-03-03 16:50:07
- Updated at : 2025-03-03 21:22:20
- Link: https://jachinzhang1.github.io/2025/03/03/tensor-dimension-operations/
- License: This work is licensed under CC BY-NC-SA 4.0.