您当前的位置: 首页 >  pytorch

wendy_ya

暂无认证

  • 1浏览

    0关注

    342博文

    0收益

  • 0浏览

    0点赞

    0打赏

    0留言

私信
关注
热门博文

PyTorch重难点(三)——详解PyTorch的自动求导(梯度)机制

wendy_ya 发布时间:2021-11-17 20:27:57 ,浏览量:1

目录
    • 一、自动求导机制
    • 二、实例介绍
      • 2.1 案例描述
      • 2.2 曲线可视化
      • 2.3 梯度计算
      • 2.4 查看梯度
      • 2.5 完整代码
    • 三、方法二

一、自动求导机制

神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情。 而PyTorch深度学习框架可以帮助我们自动地完成这种求梯度运算。 在计算梯度之前,需要首先明确哪个变量需要计算梯度,将需要计算梯度的张量的requires_grad参数设置为True。 Pytorch一般通过反向传播 backward 方法 实现这种求梯度计算。该方法求得的梯度将存在对应自变量张量的grad属性下,然后利用grad属性即可实现求梯度的计算。 除此之外,也能够调用torch.autograd.grad 函数来实现求梯度计算。

这就是Pytorch的自动求导机制。

二、实例介绍 2.1 案例描述

计算 y = a ∗ x 2 + b ∗ x + c y=a*x^2+b*x+c y=a∗x2+b∗x+c在x=2处的导数,其中a=1,b=-2,c=1,也即 y = x 2 − 2 x + 1 y=x^2-2x+1 y=x2−2x+1。

2.2 曲线可视化

为了方便理解,首先我们利用Matplotlib库来可视化一下曲线。 曲线 y = x 2 − 2 x + 1 y=x^2-2x+1 y=x2−2x+1在x∈[0,3]处图像如下: 在这里插入图片描述 可以发现,当x=2.0时,y=1.0。

曲线 y = x 2 − 2 x + 1 y=x^2-2x+1 y=x2−2x+1的导数即 y = 2 x − 2 y=2x-2 y=2x−2在x∈[0,3]处图像如下: 在这里插入图片描述 可以发现,当x=2.0时,y=2.0。

2.3 梯度计算

要计算当x=2时y的梯度,即dy/dx,PyTorch需要指导它依赖于哪个张量以及依赖关系的数学表达式,之后便能计算出dy/dx。

y.backward()  #计算梯度

上面的代码完成所有这些步骤,通过观察y,PyTorch发现它来自 x 2 − 2 x + 1 x^2-2x+1 x2−2x+1,并自动算出梯度 d y / d x = 2 x − 2 dy/dx=2x-2 dy/dx=2x−2。 同时,这行代码也计算出梯度的数值,并与x的实际值一同存储在张量x中。

2.4 查看梯度

可以通过grad属性查看梯度值:

x.grad    #查看x=2时的梯度

运行结果: tensor(4.)

2.5 完整代码

完整代码如下:

import torch
# f(x) = a*x**2 + b*x + c的导数
x=torch.tensor(3.,requires_grad=True)  #设置x需要被求导
a=torch.tensor(1.)
b=torch.tensor(-2.)
c=torch.tensor(1.)
y=a*torch.pow(x,2)+b*x+c

y.backward()  #计算梯度
x.grad    #查看x=2时的梯度
三、方法二

方法二可以通过调用torch.autograd.grad 函数实现梯度计算,这时就不需要backward 方法了。

完整代码如下:

import numpy as np 
import torch 

# f(x) = a*x**2 + b*x + c的导数
x = torch.tensor(0.0,requires_grad = True) # x需要被求导
a = torch.tensor(1.0)
b = torch.tensor(-2.0)
c = torch.tensor(1.0)
y = a*torch.pow(x,2) + b*x + c

# create_graph 设置为 True 将允许创建更高阶的导数
dy_dx = torch.autograd.grad(y,x,create_graph=True)[0]
print(dy_dx)  # 求导结果

运行结果: tensor(4.)

ok,以上便是本文的全部内容了,如果对你有所帮助,记得点个赞哟~

关注
打赏
1659256378
查看更多评论
立即登录/注册

微信扫码登录

0.0459s