参考链接
torch.cat
torch.cat(tensors, dim=0, *, out=None) → Tensor
将给定序列中的 seq
张量在指定维度上进行连接。所有张量必须在连接维度之外具有相同的形状,或者是空张量。
torch.cat()
可以被视为 torch.split()
和 torch.chunk()
的逆操作。
参数
-
tensors
(张量序列)- 相同类型的张量的任何 Python 序列。提供的非空张量必须在cat
维度上具有相同的形状。 -
dim
(int, 可选默认为0)- 张量连接的维度
关键字参数
out
(Tensor, 可选)- 输出张量。
与torch.stack的区别
代码如下:
1 | import torch |
输出结果:
1 | Shape of cat_result: torch.Size([4, 3]) |