pytorch怎么改变tensor维度

1413
2024/12/31 6:32:48
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,可以使用unsqueeze()squeeze()方法来改变Tensor的维度。

unsqueeze()方法可以在指定位置插入一个新的维度,其值都为1。例如,如果你有一个形状为(3,)的Tensor,你可以使用unsqueeze()方法将其形状变为(1, 3):

import torch

x = torch.tensor([1, 2, 3])
y = x.unsqueeze(0)
print(y.shape)  # 输出(1, 3)

squeeze()方法可以删除指定位置的维度,其大小为1。例如,如果你有一个形状为(1, 3)的Tensor,你可以使用squeeze()方法将其形状变为(3,):

import torch

x = torch.tensor([[1, 2, 3]])
y = x.squeeze(0)
print(y.shape)  # 输出(3,)

注意,unsqueeze()squeeze()方法都不会改变Tensor中的数据,只会改变其形状。

辰迅云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读: PyTorch PyG怎样优化模型结构