The differences between repeat, tile, and repeat_interleave in PyTorch

torch.Tensor.repeat

repeat can be pictured as repeating an existing tensor several times to form a "block matrix":

import torch

""" Example 1 """
t = torch.arange(3)
print(t.repeat((2, )))
# tensor([0, 1, 2, 0, 1, 2])
print(t.repeat((2, 2)))
# tensor([[0, 1, 2, 0, 1, 2],
#         [0, 1, 2, 0, 1, 2]])

""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat((2, )))
# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
print(t.repeat((2, 2)))
# tensor([[0, 1, 0, 1],
#         [2, 3, 2, 3],
#         [0, 1, 0, 1],
#         [2, 3, 2, 3]])
print(t.repeat((2, 3, 4)))
# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3]],
# 
#         [[0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3]]])

As the error shows, the number of repeat dims passed to repeat cannot be smaller than the number of dimensions of the tensor itself.
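
To make the "block matrix" picture concrete, here is a small sketch of my own (not from the original post), reusing the 2x2 t from Example 2: it rebuilds t.repeat((2, 2)) by hand with torch.cat.

""" Sketch: building the block matrix manually """
t = torch.arange(4).reshape(2, 2)
row = torch.cat([t, t], dim=1)        # one "block row": shape (2, 4)
block = torch.cat([row, row], dim=0)  # stack two block rows: shape (4, 4)
print(torch.equal(block, t.repeat((2, 2))))
# True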

torch.Tensor.tile

In most cases, tile behaves the same as repeat, as shown below:

""" Example 1 """
t = torch.arange(3)
print(t.tile((2, )))
# tensor([0, 1, 2, 0, 1, 2])
print(t.tile((2, 2)))
# tensor([[0, 1, 2, 0, 1, 2],
#         [0, 1, 2, 0, 1, 2]])

""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.tile((2, )))
# tensor([[0, 1, 0, 1],
#         [2, 3, 2, 3]])
print(t.tile((2, 2)))
# tensor([[0, 1, 0, 1],
#         [2, 3, 2, 3],
#         [0, 1, 0, 1],
#         [2, 3, 2, 3]])
print(t.tile((2, 3, 4)))
# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3]],
# 
#         [[0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3],
#          [0, 1, 0, 1, 0, 1, 0, 1],
#          [2, 3, 2, 3, 2, 3, 2, 3]]])

repeat 不同的是,当要重复的维度低于张量的维度时,tile 也能够处理,此时 tile 会使用前置

1

1

1 自动补齐维度。
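
A quick sanity check of this padding rule (a minimal sketch of my own, reusing the 2x2 t from Example 2): with the leading 1 filled in, tile((2, )) is the same as tile((1, 2)), which in turn matches repeat((1, 2)).

""" Sketch: leading 1s are prepended automatically """
t = torch.arange(4).reshape(2, 2)
print(torch.equal(t.tile((2, )), t.tile((1, 2))))
# True
print(torch.equal(t.tile((2, )), t.repeat((1, 2))))
# True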

torch.Tensor.repeat_interleave

The repeat and tile discussed above both repeat the tensor as a whole, whereas repeat_interleave repeats the individual elements of the tensor.

Its parameters are as follows:

torch.Tensor.repeat_interleave(repeats, dim=None)
  • repeats: the number of times each element should be repeated; can be an integer or a tensor;
  • dim: the dimension along which to repeat. By default the input is flattened before repeating, and the output is returned as a flat tensor.

""" Example 1 """
t = torch.arange(3)
print(t.repeat_interleave(repeats=3))
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])

""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat_interleave(repeats=3))
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
print(t.repeat_interleave(repeats=3, dim=0))
# tensor([[0, 1],
#         [0, 1],
#         [0, 1],
#         [2, 3],
#         [2, 3],
#         [2, 3]])
print(t.repeat_interleave(repeats=3, dim=1))
# tensor([[0, 0, 0, 1, 1, 1],
#         [2, 2, 2, 3, 3, 3]])

""" Example 3 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat_interleave(repeats=torch.tensor([2, 3]), dim=0))  # repeat the 1st row of t 2 times and the 2nd row 3 times
# tensor([[0, 1],
#         [0, 1],
#         [2, 3],
#         [2, 3],
#         [2, 3]])
print(t.repeat_interleave(repeats=torch.tensor([3, 2]), dim=1))  # repeat the 1st column of t 3 times and the 2nd column 2 times
# tensor([[0, 0, 0, 1, 1],
#         [2, 2, 2, 3, 3]])
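
To sum up the difference, here is a small side-by-side comparison of my own on the same 1-D tensor: repeat copies the tensor as a whole, while repeat_interleave copies it element by element.

""" Sketch: repeat vs. repeat_interleave on the same tensor """
t = torch.arange(3)
print(t.repeat((2, )))
# tensor([0, 1, 2, 0, 1, 2])   <- the whole tensor is repeated
print(t.repeat_interleave(repeats=2))
# tensor([0, 0, 1, 1, 2, 2])   <- each element is repeated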
