Batch Normalization
为什么需要BN
- 网络中存在 Internal covariate shift (ICS)
那么ICS会导致什么问题呢?
- 每次网络更新后,深层的 layer 可能都会面对一个全新的输入分布,这就会导致网络的收敛速度变得更慢
Circular transclusion detected: 80_Resources/Deep-Learning/Internal-covariate-shift
Circular transclusion detected: 80_Resources/Deep-Learning/Internal-covariate-shift
- 即使解决了 ICS 问题后,也需要考虑不同网络层对于不同的数据分布的要求 假设我们通过简单的归一化解决了 ICS 问题,但是实际上就会导致所有 layer 的输入都是相同的分布,这就降低了网络的特征表达能力。
BN 是怎么做的?
Step 1: 归一化 Step 2: 对规范化后的数据进行数据变换 这里引入两个可学习的变换参数 和
关于 和 的反向传播公式推导
如何计算 和 ?
理论上,BN 是在所有的数据上的同一个通道进行的。假设一个样本的特征图 ,我们计算所有样本在网络该位置的输出中第 个通道 () 的特征的均值 方差 .
实际上,由于实际训练中我们往往采用的是mini-batch的方式,引入了Batch size的大小N,因此这里我们只计算同一个batch内所有样本的第 个通道 () 的特征的均值 方差 .
Note
概言之,每次应用BN时,我们会计算出个均值和方差,并且对该batch内所有样本对应的通道利用对应的均值和方差进行归一化。即沿着维度进行。
测试阶段 BN 是怎么工作的?
显然,在测试阶段,我们不会再更新所有的 和 。
但是,怎么计算测试阶段的均值和方差呢? 如果我们此时只需要推理一个测试样本,那么此时一个 batch 内 (只有 1 个样本) 的均值和方差已经是整体数据的有偏估计了,直接使用可能造成性能损失。
实际应用中,在训练阶段,我们提前保留每组 mini-batch 的训练样本在网络中每一层的 和 。并利用该统计量对测试阶段的 和 进行无偏估计,得到在测试阶段实际使用的 和 ,即
\mu_{\text {test }} &=E\left(\mu_{\text {batch }}\right) \\ \sigma_{\text {test }}^{2} &=\frac{m}{m-1} E\left(\sigma_{\text {batch }}^{2}\right) \end{aligned}$$ 此后,即可正常的执行 BN 了 $$B N\left(X_{\text {test }}\right)=\gamma \cdot \frac{X_{\text {test }}-\mu_{\text {test }}}{\sqrt{\sigma_{\text {test }}^{2}+\epsilon}}+\beta$$ ## BN 放在哪里? 常见的使用使用是放在激活函数的前面,即 conv -> bn -> relu ## BN的优势 1. 缓解梯度消失 2. 加快网络收敛速度,增加网络在训练过程中的稳定性 3. 降低网络对于参数的敏感性,简化调参过程(比如学习率的调节,权重初始化的方法) 4. 引入部分正则的功能。 > 不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音(但同时也可能是挑战)。与Dropout通过关闭神经元给网络训练带来噪音类似,在一定程度上对模型起到了正则化的效果。 ## 什么时候不能用 BN? 1. Batch size 很小时。当一个 batch 中样本数量小的时候,batch 内的均值方差波动会很大,此时不能使用 BN 2. 在生成任务中或者 low level 的任务中 (超分,不知道去噪中用不用呢?) 不能用,[参考](https://www.zhihu.com/question/62599196)。 概言之,可能是由于每层的分布对于生成的任务来说依然很重要,加入 BN 后需要网络拿出一部分参数做这部分的恢复,因此会降低表现。 3. 在 RNN 中,若序列长度不是固定的,BN 就无法计算。 4. 训练数据库一定要Shuffle,否则可能会导致不同批次的均值方差差的太多 ## 使用 Numpy 实现 BN 的前向传播 参考本笔记:[[使用 Numpy 实现 BN 的前向传播|手撕BN]] --- 参考资料 - [Batch Normalization原理与实战 - 知乎 (zhihu.com)](https://zhuanlan.zhihu.com/p/34879333) - [Batch Normalization (BN) | LogM's Blog (imlogm.github.io)](https://imlogm.github.io/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/batch-normalization/)