Loading...

参考链接


torch.cat

torch.cat(tensors, dim=0, *, out=None) → Tensor
将给定序列中的 seq 张量在指定维度上进行连接。所有张量必须在连接维度之外具有相同的形状,或者是空张量。

torch.cat() 可以被视为 torch.split()torch.chunk() 的逆操作。

参数

  • tensors(张量序列)- 相同类型的张量的任何 Python 序列。提供的非空张量必须在 cat 维度上具有相同的形状。

  • dim(int, 可选默认为0)- 张量连接的维度

关键字参数

  • out(Tensor, 可选)- 输出张量。

与torch.stack的区别

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch

# 创建两个形状为 [2, 3] 的张量
tensor1 = torch.randn(2, 3)
tensor2 = torch.randn(2, 3)

# 使用 torch.cat 在第一个维度(dim=0)上连接张量
cat_result = torch.cat((tensor1, tensor2), dim=0)

# 使用 torch.stack 在一个新的维度(dim=0)上堆叠张量
stack_result = torch.stack((tensor1, tensor2), dim=0)

# 打印结果和它们的形状
#print("Result of torch.cat:")
#print(cat_result)
print("Shape of cat_result:", cat_result.shape)

#print("\nResult of torch.stack:")
#print(stack_result)
print("Shape of stack_result:", stack_result.shape)

输出结果:

1
2
Shape of cat_result: torch.Size([4, 3])
Shape of stack_result: torch.Size([2, 2, 3])