PyTorch中repeat、tile与repeat_interleave的区别
torch.Tensor.repeat
repeat
可以形象地理解为将已有的张量多次重复以组成 “分块矩阵”。
import torch
""" Example 1 """
t = torch.arange(3)
print(t.repeat((2, )))
# tensor([0, 1, 2, 0, 1, 2])
print(t.repeat((2, 2)))
# tensor([[0, 1, 2, 0, 1, 2],
# [0, 1, 2, 0, 1, 2]])
""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat((2, )))
# RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
print(t.repeat((2, 2)))
# tensor([[0, 1, 0, 1],
# [2, 3, 2, 3],
# [0, 1, 0, 1],
# [2, 3, 2, 3]])
print(t.repeat((2, 3, 4)))
# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3]],
#
# [[0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3]]])
可以看出要 repeat
的维度不能低于张量本身的维度。
torch.Tensor.tile
大部分情况下,tile
与 repeat
的作用相同,如下:
""" Example 1 """
t = torch.arange(3)
print(t.tile((2, )))
# tensor([0, 1, 2, 0, 1, 2])
print(t.tile((2, 2)))
# tensor([[0, 1, 2, 0, 1, 2],
# [0, 1, 2, 0, 1, 2]])
""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.tile((2, )))
# tensor([[0, 1, 0, 1],
# [2, 3, 2, 3]])
print(t.tile((2, 2)))
# tensor([[0, 1, 0, 1],
# [2, 3, 2, 3],
# [0, 1, 0, 1],
# [2, 3, 2, 3]])
print(t.tile((2, 3, 4)))
# tensor([[[0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3]],
#
# [[0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3],
# [0, 1, 0, 1, 0, 1, 0, 1],
# [2, 3, 2, 3, 2, 3, 2, 3]]])
与 repeat
不同的是,当要重复的维度低于张量的维度时,tile
也能够处理,此时 tile
会使用前置
1
1
1 自动补齐维度。
torch.Tensor.repeat_interleave
之前提到的 repeat
和 tile
都是重复整个张量,而这次的 repeat_interleave
则是重复张量中的元素。
参数如下:
torch.Tensor.repeat_interleave(repeats, dim=None)
-
repeats
:代表张量中每个元素将要重复的次数。可以为整数或张量; -
dim
:决定了沿哪一个轴去重复数字。默认情况下会将输入展平再进行重复,最后输出展平的张量。
""" Example 1 """
t = torch.arange(3)
print(t.repeat_interleave(repeats=3))
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
""" Example 2 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat_interleave(repeats=3))
# tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
print(t.repeat_interleave(repeats=3, dim=0))
# tensor([[0, 1],
# [0, 1],
# [0, 1],
# [2, 3],
# [2, 3],
# [2, 3]])
print(t.repeat_interleave(repeats=3, dim=1))
# tensor([[0, 0, 0, 1, 1, 1],
# [2, 2, 2, 3, 3, 3]])
""" Example 3 """
t = torch.arange(4).reshape(2, 2)
print(t.repeat_interleave(repeats=torch.tensor([2, 3]), dim=0)) # t的第一行重复2次,第2行重复3次
# tensor([[0, 1],
# [0, 1],
# [2, 3],
# [2, 3],
# [2, 3]])
print(t.repeat_interleave(repeats=torch.tensor([3, 2]), dim=1)) # t的第一列重复3次,第2列重复2次
# tensor([[0, 0, 0, 1, 1],
# [2, 2, 2, 3, 3]])
本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
THE END
二维码