目录
1. 降维torch.squeeze(input, dim=None, out=None)
简单示例
matplotlib画图示例
2.增维 torch.unsqueeze(input, dim, out=None)
简单示例
3.参考
1. 降维torch.squeeze(input, dim=None, out=None)函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
- 当给定dim时,那么挤压操作只在给定维度上。即若tensor.size(dim) = 1,则去掉该维度
- 其中squeeze(0)代表若第一维度值为1则去除第一维度
- squeeze(1)代表若第二维度值为1则去除第二维度
- -1,去除最后维度值为1的维度
- 当不给定dim时,将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D)(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)(A×B×C×D)
例如,输入形状为: (A×1×B)(A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)(A×B)。
注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
参数:
- input (Tensor) – 输入张量
- dim (int, optional) – 如果给定,则input只会在给定维度挤压,维度的索引(从0开始)
- out (Tensor, optional) – 输出张量
a = torch.Tensor(1,3)
>>
tensor([[-1.37,4.56,-3.57]])
print a.squeeze(0) #第一个维度大小确实是1,所以可以去除
>>
tensor([-1.37,4.56,-3.57])
print a.squeeze(1) ##第二个维度大小是3,所以不能去除
>>
tensor([[-1.37,4.56,-3.57]])
#例子2
b = torch.Tensor(2,3)
print b
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
print b.squeeze(0)##第一个维度大小不是1,所以不能去除
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
print b.squeeze(1) ##第二个维度大小是3,所以不能去除
>>
tensor([[-3.17,3.09,1.43],
[0.00,0.00,0.00]])
#例子3
c = torch.Tensor(3,1)
print c
>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(0)##第一个维度大小不是1,所以不能去除
>>
tensor([[-3.54],
[3.09],
[0.00]])
print c.squeeze(1)#第二个维度大小确实是1,所以可以去除
>>
tensor([-3.54,3.09,0.00])
matplotlib画图示例
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
#无法正常显示图示案例
squares =np.array([[1,4,9,16,25]])
squares.shape #要显示的数组为可表示1行5列的向量的数组
(1, 5)
plt.plot(squares)
plt.show()
#正常显示图示案例
#通过np.squeeze()函数转换后,要显示的数组变成了秩为1的数组,即(5,)
plt.plot(np.squeeze(squares))
plt.show()
np.squeeze(squares).shape
(5,)
2.增维 torch.unsqueeze(input, dim, out=None)
增加大小为1的维度,也就是返回一个新的张量,对输入的指定位置插入维度 1且必须指明维度
- x = torch.unsqueeze(x, 3) # 在第3个维度上扩展
注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果dim为负,则将会被转化dim+input.dim()+1,例如对于一个(3,2,4)的tensor,其dim可以选择为none,-1,0,1,2
- none:所有元素的max,得到一个max值
- -1:若dim为负,则将被转化为dim+input.dim()+1,即2
- -1+2+1
- 0:最粗粒度的方向,在第1维插入一个维度
- 1:在第2维插入一个维度
- 2:最细粒度的方向,在第3维插入一个维度
- -3:在倒数第3维插入一个维度,在本例子也就是第一维
- 一句话概括:dim越大,越深入,none即所有最小元素参与计算。
参数:
- tensor (Tensor) – 输入张量
- dim (int) – 插入维度的索引(从0开始)
- out (Tensor, optional) – 结果张量
import torch
x = torch.ones(4)
print(x)
print(x.size())
y = torch.unsqueeze(x, 0)
print(y)
print(y.size())
z = torch.unsqueeze(x, 1)
print(z)
print(z.size())
结果
tensor([1., 1., 1., 1.])
torch.Size([4])
tensor([[1., 1., 1., 1.]])
torch.Size([1, 4])
tensor([[1.],
[1.],
[1.],
[1.]])
torch.Size([4, 1])
分析
插入维度之前:
[ 1, 1, 1, 1 ]
在第0维插入一个维度,使其变成(1,4),即在最外层插入一个中括号即可:
[ [ 1, 1, 1, 1 ] ]
在第1维插入一个维度,使其变成(4,1)
[ [1], [1], [1], [1] ]
3.参考pytorch中对维度及其squeeze()、unsqueeze()函数的理解
torch.squeeze()和unsqueeze()
Numpy库学习—squeeze()函数