torch.cat()
torch.cat(tensors, dim=0, *, out=None) → Tensor
描述在给定维度连接给定的张量序列。所有张量必须要么具有相同的形状(要连接的维度除外),要么为空。
参数tensors
(张量序列) – 相同类型张量的任何python序列。提供的非空张量必须具有相同的形状,要拼接的维度除外。dim
(int, 可选) – 张量连接的维度;默认是0
。
>>> import torch
>>> a = torch.rand(2, 3)
>>> b = torch.rand(2, 3)
>>> c = torch.cat((a, b), dim=0)
>>> c.shape
torch.Size([4, 3])
>>> c = torch.cat((a, b), dim=1)
>>> c.shape
torch.Size([2, 6])
torch.stack()
torch.stack(tensors, dim=0, *, out=None) → Tensor
描述沿着一个新的维度连接一个张量序列。所有张量都必须具有相同的大小。
参数tensors
(张量序列) – 要连接的张量序列。dim
(int) – 插入的维度。必须在0
和连接张量的维数之间(包含)。
>>> import torch
>>> a = torch.rand(1, 3)
>>> b = torch.rand(1, 3)
>>> c = torch.stack((a, b), dim=0)
>>> c.shape
torch.Size([2, 1, 3])
>>> d = torch.stack((a, b), dim=1)
>>> d.shape
torch.Size([1, 2, 3])
>>> e = torch.stack((a, b), dim=2)
>>> e.shape
torch.Size([1, 3, 2])
引用参考
https://pytorch.org/docs/stable/generated/torch.cat.html https://pytorch.org/docs/stable/generated/torch.stack.html