参考链接
torch.cat
torch.cat(tensors, dim=0, *, out=None) → Tensor
将给定序列中的 seq 张量在指定维度上进行连接。所有张量必须在连接维度之外具有相同的形状,或者是空张量。
torch.cat() 可以被视为 torch.split() 和 torch.chunk() 的逆操作。
参数
-
tensors(张量序列)- 相同类型的张量的任何 Python 序列。提供的非空张量必须在cat维度上具有相同的形状。 -
dim(int, 可选默认为0)- 张量连接的维度
关键字参数
out(Tensor, 可选)- 输出张量。
与torch.stack的区别
代码如下:
1 | import torch |
输出结果:
1 | Shape of cat_result: torch.Size([4, 3]) |