目录
一、案例描述
- 一、案例描述
- 二、数据集介绍
- 三、构建神经网络类(class)
-
- 3.1 设计神经网络层
- 3.2 损失函数和权重更新
- 3.3 实现forward()方法
- 3.4 定义train()函数
- 3.5 训练可视化
- 四、训练模型
- 五、测试模型性能
- 六、查看训练效果
- 七、拓展——算法改进
-
- 7.1 损失函数
- 7.2 激活函数
- 7.3 优化器
本文以MNIST数据集为例,详解PyTorch搭建神经网络对MNIST数据集进行分类。
二、数据集介绍加载数据集利用datasets.MNIST()函数,其用法如下: datasets.MNIST(root, train=True, transform=None, download=False