PyTorch 模型训练实用教程
作者:余霆嵩
PyTorch 模型训练实⽤教程
前言:
自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超
TensorFlow 的趋势。PyTorch 能在短时间内被众多研究人员和工程师接受并推
崇是因为其有着诸多优点,如采用 Python 语言、动态图机制、网络构建灵活以
及拥有强大的社群等。因此,走上学习 PyTorch 的道路已刻不容缓。
本教程以实际应用、工程开发为目的,着重介绍模型训练过程中遇到的实
际问题和方法。如上图所示,在机器学习模型开发中,主要涉及三大部分,分
别是数据、模型和损失函数及优化器。本文也按顺序的依次介绍数据、模型和
损失函数及优化器,从而给大家带来清晰的机器学习结构。
通过本教程,希望能够给大家带来一个清晰的模型训练结构。当模型训练
遇到问题时,需要通过可视化工具对数据、模型、损失等内容进行观察,分析
并定位问题出在数据部分?模型部分?还是优化器?只有这样不断的通过可视
化诊断你的模型,不断的对症下药,才能训练出一个较满意的模型。
为什么写此教程:
前几年一直在用 Caffe 和 MatConvNet,近期转 PyTorch。当时只想快速地
用上 PyTorch 进行模型开发,然而搜了一圈 PyTorch 的教程,并没有找到一款
本教程仅限于学习交流使用,严禁用于商业用途!
I
PyTorch 模型训练实用教程
作者:余霆嵩
适合的。很多 PyTorch 教程是从学习机器学习(深度学习)的角度出发,以
PyTorch 为工具进行编写,里面介绍很多模型,并且附上模型的 demo。
然而,工程应用开发中所遇到的问题并不是跑一个模型的 demo 就可以的,
模型开发需要对数据的预处理、数据增强、模型定义、权值初始化、模型
Finetune、学习率调整策略、损失函数选取、优化器选取、可视化等等。鉴于
此,我只能自己对着官方文档,一步一步地学习。
起初,只是做了一些学习笔记,后来觉得这些内容应该对大家有些许帮
助,毕竟在互联网上很难找到这类内容的分享,于是此教程就诞生了。
本教程内容及结构:
本教程内容主要为在 PyTorch 中训练一个模型所可能涉及到的方法及函
数,并且对 PyTorch 提供的数据增强方法(22 个)、权值初始化方法(10
个)、损失函数(17 个)、优化器(6 个)及 tensorboardX 的方法(13 个)
进行了详细介绍。
本教程分为四章,结构与机器学习三大部分一致。
第一章,介绍数据的划分,预处理,数据增强;
第二章,介绍模型的定义,权值初始化,模型 Finetune;
第三章,介绍各种损失函数及优化器;
第四章,介绍可视化工具,用于监控数据、模型权及损失函数的变化。
本教程适用读者:
1.想熟悉 PyTorch 使用的朋友;
2.想采用 PyTorch 进行模型训练的朋友;
3.正采用 PyTorch,但无有效机制去诊断模型的朋友;
干货直达:
1.6 transforms 的二十二个方法
2.2 权值初始化的十种方法
3.1 PyTorch 的十七个损失函数
3.3 PyTorch 的十个优化器
3.4 PyTorch 的六个学习率调整方法
4.1 TensorBoardX
项目代码:https://github.com/tensor-yu/PyTorch_Tutorial
本教程仅限于学习交流使用,严禁用于商业用途!
II
PyTorch 模型训练实用教程
作者:余霆嵩
意见反馈:yts3221@126.com
学习交流 QQ 群:为了更好的帮助大家学习和理解 PyTorch 以及机器学
习相关知识,特建立一个 QQ 群,供大家交流,本文的最新修改也会同步到 QQ
群及 GitHub。QQ 群号:671103375
本教程仅限于学习交流使用,严禁用于商业用途!
III
PyTorch 模型训练实用教程
作者:余霆嵩
⽬录
第一章 数 据 ....................................................... 1
1.1 Cifar10 转 png ................................................ 1
1.2 训练集、验证集和测试集的划分 ............................... 2
1.3 让 PyTorch 能读你的数据集 ................................... 2
1.4 图片从硬盘到模型 ........................................... 5
1.5 数据增强 与 数据标准化 ..................................... 7
1.6 transforms 的二十二个方法 .................................... 10
1.随机裁剪:transforms.RandomCrop........................... 11
2.中心裁剪:transforms.CenterCrop ............................ 12
3.随机长宽比裁剪 transforms.RandomResizedCrop ............... 12
4.上下左右中心裁剪:transforms.FiveCrop ...................... 12
5.上下左右中心裁剪后翻转: transforms.TenCrop ................. 13
6.依概率 p 水平翻转 transforms.RandomHorizontalFlip ............ 13
7.依概率 p 垂直翻转 transforms.RandomVerticalFlip .............. 13
8.随机旋转:transforms.RandomRotation ........................ 13
9.resize:transforms.Resize ................................... 14
10.标准化:transforms.Normalize .............................. 14
11.转为 tensor:transforms.ToTensor ........................... 14
12.填充:transforms.Pad ..................................... 14
13.修改亮度、对比度和饱和度:transforms.ColorJitter ............ 15
14.转灰度图:transforms.Grayscale ............................ 15
15.线性变换:transforms.LinearTransformation() ................. 15
16.仿射变换:transforms.RandomAffine ........................ 15
17.依概率 p 转为灰度图:transforms.RandomGrayscale............ 16
18.将数据转换为 PILImage:transforms.ToPILImage .............. 16
19.transforms.Lambda ........................................ 16
20.transforms.RandomChoice(transforms) ........................ 16
21.transforms.RandomApply(transforms, p=0.5) ................... 16
22.transforms.RandomOrder ................................... 16
本教程仅限于学习交流使用,严禁用于商业用途!
IV
PyTorch 模型训练实用教程
作者:余霆嵩
第二章 模 型 ...................................................... 17
2.1 模型的搭建 ................................................ 17
2.1.1 模型定义的三要素 ..................................... 17
2.1.2 模型定义多说两句 ..................................... 18
2.1.3 nn.Sequetial ........................................... 21
2.2 权值初始化的十种方法 ...................................... 22
2.2.1 权值初始化流程 ....................................... 22
2.2.2 常用初始化方法 ....................................... 23
1. Xavier 均匀分布 .......................................... 24
2. Xavier 正态分布 .......................................... 24
3. kaiming 均匀分布 ......................................... 24
4. kaiming 正态分布 ......................................... 24
5. 均匀分布初始化 ......................................... 25
6. 正态分布初始化 ......................................... 25
7. 常数初始化 ............................................. 25
8. 单位矩阵初始化 ......................................... 25
9. 正交初始化 ............................................. 25
10. 稀疏初始化 ............................................ 26
11. 计算增益 .............................................. 26
权值初始化杂谈 ............................................ 26
2.3 模型 Finetune ............................................. 27
第三章 损失函数与优化器 ........................................... 30
3.1 PyTorch 的十七个损失函数 ................................... 30
1. L1loss................................................... 30
2. MSELoss ................................................ 30
3. CrossEntropyLoss ......................................... 31
4. NLLLoss ................................................ 32
5. PoissonNLLLoss .......................................... 33
6. KLDivLoss .............................................. 34
7. BCELoss ................................................ 35
本教程仅限于学习交流使用,严禁用于商业用途!
V
PyTorch 模型训练实用教程
作者:余霆嵩
8. BCEWithLogitsLoss ....................................... 36
9. MarginRankingLoss ....................................... 36
10. HingeEmbeddingLoss ..................................... 37
11. MultiLabelMarginLoss .................................... 37
12. SmoothL1Loss ........................................... 38
13. SoftMarginLoss .......................................... 39
14. MultiLabelSoftMarginLoss ................................. 39
15. CosineEmbeddingLoss .................................... 40
16. MultiMarginLoss ......................................... 40
17. TripletMarginLoss ........................................ 41
3.2 优化器基类:Optimizer ..................................... 42
3.3 PyTorch 的十个优化器 ...................................... 44
1. torch.optim.SGD ........................................ 44
2. torch.optim.ASGD ....................................... 45
3. torch.optim.Rprop ...................................... 45
4. torch.optim.Adagrad .................................... 46
5. torch.optim.Adadelta ................................... 46
6. torch.optim.RMSprop .................................... 46
7. torch.optim.Adam(AMSGrad) .............................. 47
8. torch.optim.Adamax ..................................... 47
9. torch.optim.SparseAdam ................................. 47
10.torch.optim.LBFGS ...................................... 48
3.4 PyTorch 的六个学习率调整方法 .............................. 48
1. lr_scheduler.StepLR .................................... 48
2. lr_scheduler.MultiStepLR ............................... 49
3. lr_scheduler.ExponentialLR ............................. 49
4. lr_scheduler.CosineAnnealingLR ......................... 49
5. lr_scheduler.ReduceLROnPlateau ......................... 51
6. lr_scheduler.LambdaLR .................................. 52
学习率调整小结 ............................................ 53
step 源码阅读 ............................................. 54
第四章 监控模型——可视化 .......................................... 56
本教程仅限于学习交流使用,严禁用于商业用途!
VI
PyTorch 模型训练实用教程
作者:余霆嵩
4.1 TensorBoardX .............................................. 56
1. add_scalar() ........................................... 57
2. add_scalars() .......................................... 58
3. add_histogram() ........................................ 59
4. add_image() ............................................ 61
补充 torchvision.utils.make_grid() ........................ 61
5. add_graph() ............................................ 62
6. add_embedding() ........................................ 63
7. add_text() ............................................. 64
8. add_video() ............................................ 65
9. add_figure() ........................................... 65
10. add_image_with_boxes() ................................ 65
11. add_pr_curve() ........................................ 65
12. add_pr_curve_raw() .................................... 65
13. export_scalars_to_json() .............................. 65
4.2 卷积核可视化 .............................................. 66
4.3 特征图可视化 .............................................. 68
4.4 梯度及权值分布可视化 ...................................... 70
4.5 混淆矩阵及其可视化 ........................................ 74
本教程仅限于学习交流使用,严禁用于商业用途!
VII
PyTorch 模型训练实用教程
作者:余霆嵩
第一章 数 据
1.1 Cifar10 转 png
俗话说得好,巧妇难为无米之炊,若没有数据,我们什么也做不了。在本
教程中,为了统一大家的数据,这里采用 cifar-10 的测试集,共 10000 张图片
作为源数据,模拟真实场景中的数据。
第一步:下载 cifar-10-python.tar.gz
下载 cifar-10-python.tar.gz,存放到 /Data 文件夹下,并且解压,获得文件
夹/Data/cifar-10-batches-py/
下载方式:
1. 官网:http://www.cs.toronto.edu/~kriz/cifar.html
2. 百度云: https://pan.baidu.com/s/1NGV8g2iBAhHwjQZWTjGEbg 提取
码: p3rh
第二步:运行 1_1_cifar10_to_png.py
运行代码:Code/1_data_prepare/1_1_cifar10_to_png.py
可在文件夹 Data/cifar-10-png/raw_test/下看到 0-9 个文件夹,对应 9 个类别。
脚本中未将训练集解压出来,这里只是为了实验,因此未使用过多的数据。这里仅将
测试集中的 10000 张图片解压出来,作为原始图片,将从这 10000 张图片中划分出训练集
(train),验证集(valid),测试集(test)。
运行完成,在 Data/cifar-10-png/raw_test 下将有 10 个文件夹,对应 10 个类别
.
└── raw_test
├── 0
├── 1
├── 2
├── 3
本教程仅限于学习交流使用,严禁用于商业用途!
1