目录
1. pytorch中的数据类型
1.1 标量——维度为0——用于loss
1.2 向量——维度为1——用于bias、线性输入数据
1.3 二维tensor
1.4 三维tensor
1.5 四维tensor
1.6 其他
2. 创建tensor
3. tensor中的索引与切片
4. tensor维度变换
4.1 view 与 reshape
4.2 Squeeze 与 unsqueeze
4.3 Expand 与 repeat 维度拓展
4.3 转置 .t 只能适用于 2D矩阵
4.4 Transpose 交换维度
4.5 permute 交换维度,指定维度顺序
4.6 Brodcasting 自动扩展
5. tensor的拼接和拆分
5.1 cat
5.2 stack
5.3 split
5.4 chunk
6. tensor的基本运算
6.1 加减乘除
6.2 矩阵相乘
6.3 三维及以上矩阵相乘
6.4 pow
6.5 Exp 及log
6.6 近似函数
6.7 数值剪切clamp
7. tensor的属性统计
7.1 范数 norm
7.2 统计属性 mean sum min max prod argmax argmin
7.3 keepdim
7.4 Top-k
7.5 比较操作 gt eq equal
7.6 高阶操作 Where Gather
引言:pytorch是面向计算的GPU加速库,所以里面的所有操作对象都是tensor(张量)。本文主要介绍pytorch中的数据类型,tensor的创建,索引与切片,维度变换、拼接与拆分、基本数学运算、属性统计等函数功能及示例。(参考:深度学习入门_哔哩哔哩_bilibili)
1. pytorch中的数据类型
数据类型检查:
比如: 10个句子,每句20个单词,每个单词100个feature表示(one hot)
比如: b张照片c个通道,一个图片大小hxw大小
numel 属性 : num of element
4704 = 2x3x28x28
上图中[ ]中的逗号,表示的是维度与维度之间的并列。
tensor中切片,冒号:
a = tensor.randn(4, 3, 28,28)
- 仅一个冒号:表示取当前维度所有值 , 比如 a[:, :, :, :]
- 一个冒号+数字: 如n: 表示从取当前维度n到最后; 如 :n 表示从0取到n-1 , 比如 a[:10, 2:, :, :] 只取前10张图片的2通道和3通道
- 一个冒号+2个数字:如 n:m 表示从n取到m-1, 比如 a[10:20, :, :, :] 取第10张到第19张图片
- 两个冒号 0:10:2 从0到9,隔行采样,采样间隔是2,此处是 0,2,4,6,8。或者 ::2, 表示当前维所有值,从0开始间隔2采样
三个点: … 表示剩余的所有维度
取出指定条件下的值(打平成了一维)
view操作的基本原则: numel不变
特别注意: 数据的存储和维度顺序非常重要,需要时刻记住!
view与reshape区别: reshape是view的升级版。
因为历史上view方法已经约定了共享底层数据内存,返回的Tensor底层数据不会使用新的内存,如果在view中调用了contiguous方法,则可能在返回Tensor底层数据中使用了新的内存,这样打破了之前的约定,破坏了对之前的代码兼容性。为了解决用户使用便捷性问题,PyTorch在0.4版本以后提供了reshape方法,实现了类似于 tensor.contigous().view(*args)的功能,如果不关心底层数据是否使用了新的内存,则使用reshape方法更方便。
减小维度 与 增加维度
注意 -1是指倒数第一位
expand:将维度拓展到指定值
repeat: 将维度拓展指定倍数
使用expand时,拓展前后维度应该一致,且可拓展,比如从1->N可以,从 3-》N则不行,因为怎么复制数据系统不知道。 从1-》N 系统知道直接复制即可。
expand指定拓展后的维度,在原有的内存上进行拓展
repeat的参数,表示每个维度重复的倍数!会申请新的内存,进行拓展,比较占内存
交换维度时,一定要注意:经过view后维度会发生变化,若想确保各维度有意义,一定要跟踪好维度的变换关系,确保前后能对应的上。
维度交换,指定维度的顺序,(内部交换多次transpose)
自动地根据数据,对tensor进行 unsqueeze/expand操作,来插入维度拓展数据,以实现同维度数据之间的计算。
扩展内部规则:
几个例子:
在指定维度上进行合并(前提: 维度数量一样,只有一个维度的值不同)
会增加一个新的维度,新的维度的值等于第一个参数的个数。(要求操作的两个tensor维度数目 、各维度值应完全一致)
按照长度,拆分指定维度。当传入一个list时,按照list中指定的值,进行维度值的拆分,当传入一个常数k时,在指定维度上平均拆分(新的拆分的维度值 = 维度值/k)。
按照数量,拆分指定维度。
*相同位置的元素相乘 (element-wise)
.matmul 矩阵相乘 = @
取后两维进行矩阵相乘
各个元素单独平方。
aa.sqrt 平方根的倒数
向上取整 向下取整 四舍五入 取整数部分 取小数部分
取指定范围的值,一个参数时表示指定最小值是多少,两个参数时表示指定一个数值范围。
取哪个维度的范数那个维度就会消掉。
prod 所有元素的乘积
argmax 当不指定参数时,将多维元素打平成一维的,返回所有值中最大的值在一维数组中的索引值;当指定一个参数时,如a.argmax(dim=1) 在第一维上求最大值。
max 或者min返回两组值,一组是最值,一组是最值对应的位置。
argmax或者argmin返回一组值,即最值对应的位置。
相比max、min多返回几组值。
kthvalue: 返回第k小的值,及其所在位置。
where 代替了for for if循环操作(CPU操作),可以将运算使用GPU操作,以加速运算。
Gather
是一个查表映射的过程,使用指定的索引,映射取出指定的值放入新的tensor中去。目的:避免使用for for if, 以使用GPU进行加速。