您当前的位置: 首页 >  pytorch

惊鸿一博

暂无认证

  • 3浏览

    0关注

    535博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

深度学习_pytorch_深度学习中的tensor介绍及常用操作

惊鸿一博 发布时间:2021-11-15 11:09:37 ,浏览量:3

目录

1. pytorch中的数据类型

1.1 标量——维度为0——用于loss

1.2 向量——维度为1——用于bias、线性输入数据

1.3 二维tensor

1.4 三维tensor

1.5 四维tensor

1.6 其他   

2. 创建tensor

3. tensor中的索引与切片

4. tensor维度变换

4.1 view 与 reshape

4.2 Squeeze 与 unsqueeze

4.3 Expand 与 repeat 维度拓展

4.3 转置 .t  只能适用于 2D矩阵

4.4 Transpose 交换维度

4.5 permute 交换维度,指定维度顺序

4.6 Brodcasting 自动扩展

5. tensor的拼接和拆分

5.1 cat

5.2 stack

5.3 split

5.4 chunk

6. tensor的基本运算

6.1 加减乘除

6.2 矩阵相乘

6.3 三维及以上矩阵相乘

6.4 pow

6.5 Exp 及log 

6.6 近似函数

6.7 数值剪切clamp

7. tensor的属性统计

7.1 范数 norm

7.2  统计属性 mean sum min max prod argmax argmin

7.3 keepdim

7.4 Top-k

7.5 比较操作 gt eq equal

7.6 高阶操作 Where Gather

引言:pytorch是面向计算的GPU加速库,所以里面的所有操作对象都是tensor(张量)。本文主要介绍pytorch中的数据类型,tensor的创建,索引与切片,维度变换、拼接与拆分、基本数学运算、属性统计等函数功能及示例。(参考:深度学习入门_哔哩哔哩_bilibili)

1. pytorch中的数据类型

 

 

 数据类型检查:

 

1.1 标量——维度为0——用于loss

 

1.2 向量——维度为1——用于bias、线性输入数据

 

1.3 二维tensor

1.4 三维tensor

比如: 10个句子,每句20个单词,每个单词100个feature表示(one hot)

1.5 四维tensor

比如: b张照片c个通道,一个图片大小hxw大小

1.6 其他   

numel 属性 : num of element

 4704 = 2x3x28x28

2. 创建tensor

 

 

3. tensor中的索引与切片

 上图中[ ]中的逗号,表示的是维度与维度之间的并列。

tensor中切片,冒号:

a = tensor.randn(4, 3, 28,28)

  • 仅一个冒号:表示取当前维度所有值 , 比如 a[:, :, :, :]
  • 一个冒号+数字: 如n: 表示从取当前维度n到最后; 如 :n 表示从0取到n-1 , 比如 a[:10, 2:, :, :] 只取前10张图片的2通道和3通道
  • 一个冒号+2个数字:如 n:m 表示从n取到m-1, 比如 a[10:20, :, :, :] 取第10张到第19张图片
  • 两个冒号 0:10:2 从0到9,隔行采样,采样间隔是2,此处是 0,2,4,6,8。或者 ::2, 表示当前维所有值,从0开始间隔2采样

 三个点: … 表示剩余的所有维度

 取出指定条件下的值(打平成了一维)

 

4. tensor维度变换 4.1 view 与 reshape

 view操作的基本原则: numel不变

特别注意: 数据的存储和维度顺序非常重要,需要时刻记住!

view与reshape区别: reshape是view的升级版。

       因为历史上view方法已经约定了共享底层数据内存,返回的Tensor底层数据不会使用新的内存,如果在view中调用了contiguous方法,则可能在返回Tensor底层数据中使用了新的内存,这样打破了之前的约定,破坏了对之前的代码兼容性。为了解决用户使用便捷性问题,PyTorch在0.4版本以后提供了reshape方法,实现了类似于 tensor.contigous().view(*args)的功能,如果不关心底层数据是否使用了新的内存,则使用reshape方法更方便。

4.2 Squeeze 与 unsqueeze

减小维度 与 增加维度

注意 -1是指倒数第一位

 

 

 

 

 

4.3 Expand 与 repeat 维度拓展

expand:将维度拓展到指定值

repeat:  将维度拓展指定倍数

使用expand时,拓展前后维度应该一致,且可拓展,比如从1->N可以,从 3-》N则不行,因为怎么复制数据系统不知道。 从1-》N 系统知道直接复制即可。

expand指定拓展后的维度,在原有的内存上进行拓展

repeat的参数,表示每个维度重复的倍数!会申请新的内存,进行拓展,比较占内存

4.3 转置 .t  只能适用于 2D矩阵

4.4 Transpose 交换维度

交换维度时,一定要注意:经过view后维度会发生变化,若想确保各维度有意义,一定要跟踪好维度的变换关系,确保前后能对应的上。

 

4.5 permute 交换维度,指定维度顺序

维度交换,指定维度的顺序,(内部交换多次transpose)

 

4.6 Brodcasting 自动扩展

自动地根据数据,对tensor进行 unsqueeze/expand操作,来插入维度拓展数据,以实现同维度数据之间的计算。

 

 

 

 扩展内部规则:

 

 

 

 几个例子:

5. tensor的拼接和拆分 5.1 cat

在指定维度上进行合并(前提: 维度数量一样,只有一个维度的值不同)

 

5.2 stack

会增加一个新的维度,新的维度的值等于第一个参数的个数。(要求操作的两个tensor维度数目 、各维度值应完全一致)

5.3  split

按照长度,拆分指定维度。当传入一个list时,按照list中指定的值,进行维度值的拆分,当传入一个常数k时,在指定维度上平均拆分(新的拆分的维度值 = 维度值/k)。

5.4 chunk

按照数量,拆分指定维度。

6. tensor的基本运算 6.1 加减乘除

6.2 矩阵相乘

*相同位置的元素相乘 (element-wise)

.matmul 矩阵相乘 = @

 

6.3 三维及以上矩阵相乘

取后两维进行矩阵相乘

6.4 pow

各个元素单独平方。

aa.sqrt 平方根的倒数

6.5 Exp 及log 

6.6 近似函数

向上取整 向下取整 四舍五入 取整数部分 取小数部分

6.7 数值剪切clamp

取指定范围的值,一个参数时表示指定最小值是多少,两个参数时表示指定一个数值范围。

7. tensor的属性统计 7.1 范数 norm

取哪个维度的范数那个维度就会消掉。

7.2  统计属性 mean sum min max prod argmax argmin

prod 所有元素的乘积

argmax 当不指定参数时,将多维元素打平成一维的,返回所有值中最大的值在一维数组中的索引值;当指定一个参数时,如a.argmax(dim=1) 在第一维上求最大值。

 max 或者min返回两组值,一组是最值,一组是最值对应的位置。

argmax或者argmin返回一组值,即最值对应的位置。

7.3 keepdim

7.4 Top-k

相比max、min多返回几组值。

kthvalue: 返回第k小的值,及其所在位置。

 

7.5 比较操作 gt eq equal

7.6 高阶操作 Where Gather

where 代替了for for if循环操作(CPU操作),可以将运算使用GPU操作,以加速运算。

Gather

是一个查表映射的过程,使用指定的索引,映射取出指定的值放入新的tensor中去。目的:避免使用for for if, 以使用GPU进行加速。

 

 

 

关注
打赏
1663399408
查看更多评论
立即登录/注册

微信扫码登录

0.0418s