pytorch中张量的维度操作

Jachin Zhang

引言

对于我本人来说,在阅读深度学习项目的源码时,分析张量的维度是一件让我十分头疼的事情。在我自己写程序时也会遇到因为张量维度不匹配而报错的问题,这时的我总会一头雾水,对于如何debug无从下手。更令人悲痛的是,当你把维度不匹配的报错交给AI企图让它来解决问题时,AI往往会一通维度操作猛如虎,最后运行一看还是跟之前相同的报错。

之前我零零散散地查阅过一些关于操作维度的函数的用法,但没系统地整理过,所以今天在这里整理一下。当然,如果真的遇到了维度不匹配的报错,那还是得从头一点点排查。只不过,熟记了这些张量操作的函数产生的效果后,也许debug会比之前顺利许多。

在计算机视觉领域,图像张量常设为四维,即

1
[B, C, H, W]
  • Bbatch_size,是一次批量处理图像的数量
  • Cchannels,为图像通道数。RGB图像的通道数为3,经过某些操作(如卷积)通道数会发生改变
  • Hheight,为一张图像的高
  • Wwidth,为一张图像的宽

维度操作

调整形状

  • tensor.view()

    返回一个具有相同数据但不同形状的新张量,要求内存必须是连续的

  • tensor.reshape()

    效果类似于view(),但不要求数据连续,因此更推荐使用。

1
2
3
4
5
6
x = torch.randn((2, 3, 5, 7))
print(x.shape) # torch.Size([2, 3, 5, 7])
y = x.view(6, 5, 7)
print(y.shape) # torch.Size([6, 5, 7])
z = x.reshape(6, 5, 7)
print(z.shape) # torch.Size([6, 5, 7])

增减维度

  • tensor.unsqueeze(dim)

    dim位置增加一个维度,大小为1。(squeeze的意思为“压缩”,unsqueeze顾名思义就是伸展一个维度)

    1
    2
    3
    4
    5
    print(x.shape)                # torch.Size([2, 3, 5, 7])
    x_prime = x.unsqueeze(0)
    print(x_prime.shape) # torch.Size([1, 2, 3, 5, 7])
    x_prime = x_prime.unsqueeze(2)
    print(x_prime.shape) # torch.Size([1, 2, 1, 3, 5, 7])

  • tensor.squeeze(dim=None)

    dim参数缺省时,默认移除所有大小为1的维度。指定维度时,若该维度大小为1则将其移除,否则无影响。

    1
    2
    3
    4
    5
    6
    7
    8
    print(x_prime.shape)      # torch.Size([1, 2, 1, 3, 5, 7])
    y_prime = x_prime.squeeze()
    z1_prime = x_prime.squeeze(0)
    z2_prime = x_prime.squeeze(1)
    z3_prime = x_prime.squeeze(2)
    print(y_prime.shape) # torch.Size([2, 3, 5, 7])
    print(z1_prime.shape, z2_prime.shape, z3_prime.shape)
    # torch.Size([2, 1, 3, 5, 7]) torch.Size([1, 2, 1, 3, 5, 7]) torch.Size([1, 2, 3, 5, 7])

交换维度

  • tensor.permute(dims)

    交换多个维度并返回新张量(不修改原始数据)。

  • tensor.transpose(dim0, dim1)

    交换传入的两个特定的维度。

1
2
3
4
5
print(y.shape)              # torch.Size([6, 5, 7])
y_prime = y.permute(1, 2, 0)
print(y_prime.shape) # torch.Size([5, 7, 6])
y_prime = y.transpose(1, 2)
print(y_prime.shape) # torch.Size([6, 7, 5])

维度展开与合并

  • tensor.flatten(start_dim=0. end_dim=-1)

    start_dimend_dim之间的维度合并为一个。参数缺省时分别默认为第一个和最后一个维度。

    1
    2
    3
    4
    5
    print(z.shape)            # torch.Size([6, 5, 7])
    z_flatten = z.flatten()
    print(z_flatten.shape) # torch.Size([210])
    z_flatten = z.flatten(0, 1)
    print(z_flatten.shape) # torch.Size([30, 7])

  • tensor.unflatten(dim, sizes)

    dim维度拆分成sizes指定的多个维度。

    1
    2
    3
    print(z_flatten.shape)    # torch.Size([30, 7])
    z_unflatten = z_flatten.unflatten(0, (3, 10))
    print(z_unflatten.shape) # torch.Size([3, 10, 7])

张量扩展

  • tensor.repeat(*sizes)

    沿指定维度重复张量,传入函数的参数数量要和张量的维数一样。作用于每个维度,独立复制元素。

    1
    2
    3
    print(x.shape)            # torch.Size([2, 3, 5, 7])
    x_repeat = x.repeat(4, 2, 1, 1)
    print(x_repeat.shape) # torch.Size([8, 6, 5, 7])

  • tensor.tiles(sizes)

    功能类似repeat(),作用于整个张量,复制整个块。

    1
    2
    3
    print(x.shape)            # torch.Size([2, 3, 5, 7])
    x_tile = x.tile(5, 3, 4, 1)
    print(x_tile.shape) # torch.Size([10, 9, 20, 7])

  • tensor.expand(*sizes)

    主要用于无额外内存开销地扩展张量,可以让一个张量沿着指定的维度进行广播,但不会实际复制数据。通常用于在计算时匹配张量的形状,而不增加额外的内存消耗,只是改变张量的视图。传入的*sizes相当于想要的形状。

    注意:expand()只能扩展大小为1的维度,否则会报错。

    1
    2
    3
    x = torch.randn(1, 2, 3)
    x_expanded = x.expand(4, 2, 3)
    print(x_expanded.shape) # torch.Size([4, 2, 3])

函数 是否额外占用内存 作用
expand() 仅改变张量视图,广播显示
repeat() 按元素复制张量
tile() 按整体块复制张量

张量拼接

  • torch.cat(tensors, dim=0)

    沿dim维度拼接多个张量,这些张量在除了dim以外的其他维度形状必须一致。所得的新张量维数和它们相同。

    1
    2
    3
    4
    5
    6
    7
    x = torch.randn(2, 3, 5)
    y = torch.randn(2, 3, 5)
    z1 = torch.cat((x, y), dim=0)
    z2 = torch.cat((x, y), dim=1)
    z3 = torch.cat((x, y), dim=2)
    print(z1.shape, z2.shape, z3.shape)
    # torch.Size([4, 3, 5]) torch.Size([2, 6, 5]) torch.Size([2, 3, 10])

  • torch.stack(tensors, dim=0)

    在指定dim维度上新增一个维度,然后在该维度上堆叠张量。所有被堆叠的张量的所有维度形状都必须相同。

    可以理解为stack = unsqueeze + cat,即对所有张量先在dim处新增一个维度,再在该维度上进行cat操作。

    1
    2
    3
    4
    5
    6
    7
    print(x.shape)            # torch.Size([2, 3, 5])
    print(y.shape) # torch.Size([2, 3, 5])
    z4 = torch.stack((x, y), dim=0)
    z5 = torch.stack((x, y), dim=1)
    z6 = torch.stack((x, y), dim=2)
    print(z4.shape, z5.shape, z6.shape)
    # torch.Size([2, 2, 3, 5]) torch.Size([2, 2, 3, 5]) torch.Size([2, 3, 2, 5])

总结:

操作 方法 作用
调整形状 view() / reshape() 改变形状
增加维度 unsqueeze() 增加维度
减少维度 squeeze() 移除大小为 1 的维度
交换维度 permute() / transpose() 交换多个维度 / 交换两个维度
展开/合并 flatten() / unflatten() 合并维度 / 拆分维度
重复数据 repeat() / tile() / expand() 沿维度复制数据
拼接数据 cat() / stack() 拼接多个张量

张量运算

爱因斯坦求和约定 einsum

torch.einsum提供了一种基于爱因斯坦求和约定的简洁表达方式,能够高效执行各种张量操作,包括矩阵乘法、内积、外积、转置等。

语法为:

1
torch.einsum(equation, *operands)

其中: - equation:字符串,定义操作规则 - operands:一个或者多个张量

爱因斯坦求和约定的内容为: - 若索引重复,表示在该维度上求和 - 若索引不重复,表示保留该维度 - 若索引在输出部分缺失,表示该维度被求和消去

常见应用: 1. 矩阵乘法

1
2
3
4
5
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# ij 表示 A 形状,jk 表示 B 形状。j是重复索引,在该维度上求和
C = torch.einsum('ij,jk->ik', A, B) # 等价于 torch.matmul(A, B)
print(C.shape) # (2, 4)
2. 向量点积
1
2
3
4
5
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# i,i-> 说明 i 在两个向量中都存在,按元素相乘后求和,得到一个标量
dot_product = torch.einsum('i,i->', a, b) # 等价于 torch.dot(a, b)
print(dot_product) # 32 (1×4 + 2×5 + 3×6)
3. 向量外积
1
2
3
4
5
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
# i,j->ij 表示保留 i 和 j 维度,不求和,得到外积,形状 (3,3)
outer_product = torch.einsum('i,j->ij', a, b) # 等价于 a.unsqueeze(1) * b.unsqueeze(0)
print(outer_product)
4. 批量矩阵乘法
1
2
3
4
5
A = torch.randn(5, 2, 3)  # 5个 (2x3) 矩阵
B = torch.randn(5, 3, 4) # 5个 (3x4) 矩阵

C = torch.einsum('bij,bjk->bik', A, B) # 等价于 torch.matmul(A, B)
print(C.shape) # (5, 2, 4)
5. 计算矩阵的迹
1
2
3
4
A = torch.randn(3, 3)
# ii-> 表示对角线求和,得到矩阵的迹
trace = torch.einsum('ii->', A) # 等价于 torch.trace(A)
print(trace)
6. 矩阵转置
1
2
3
4
A = torch.randn(2, 3)
# ij->ji 交换 i 和 j,实现转置
A_T = torch.einsum('ij->ji', A) # 等价于 A.T
print(A_T.shape) # (3,2)

张量乘法 matmul

matmul的行为取决于张量的维度。低维张量(如标量、向量和矩阵)的运算较为简单:

A维度 B维度 torch.matmul(A, B)维度
0(标量) 0 0
1(向量) 1 0(点积)
2(矩阵) 1 1
2 2 2

对于高维张量(三维及以上),torch.matmul(A, B)采用批量矩阵乘法。即从右往左取AB的最后两个维度进行标准矩阵乘法,前面的批量维度自动对齐(需要保证可广播)

1
2
3
4
5
6
a = torch.randn(1, 4, 2, 3)
b = torch.randn(5, 1, 3, 4)
a_b = torch.matmul(a, b)
print(a.shape) # torch.Size([1, 4, 2, 3])
print(b.shape) # torch.Size([5, 1, 3, 4])
print(a_b.shape) # torch.Size([5, 4, 2, 4])

同时,pytorch还提供了torch.bmm()用于严格的3D批量矩阵乘法,即仅适用于(b, m, k) @ (b, k, n)

  • Title: pytorch中张量的维度操作
  • Author: Jachin Zhang
  • Created at : 2025-03-03 16:50:07
  • Updated at : 2025-03-03 21:22:20
  • Link: https://jachinzhang1.github.io/2025/03/03/tensor-dimension-operations/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments