Regularization#
正则化是一种在深度学习中使用的技术,旨在减少模型的过拟合。它通过在目标函数中添加一个正则化项,引入模型复杂度的惩罚,以约束参数的大小或参数之间的关系,从而提高泛化能力并提高模型在未见过的数据上的性能。
Overfitting and Underfitting#
Overfitting#
过拟合(Overfitting)指的是模型过度拟合训练数据,在训练数据上表现很好,但在新数据上的表现较差。这是因为模型过于复杂,过度关注训练数据的细节和噪声,无法很好地泛化到未见过的数据。
Underfitting#
欠拟合(Underfitting)则是指模型无法很好地拟合训练数据,无法捕捉到数据中的关键特征和模式,导致在训练数据和新数据上的表现都不理想。
- 增加模型复杂度
- 特征工程(Feature Engineering)
- 增加训练数据量
- 减少正则化
- 使用更复杂的模型
L1 Regularization (Lasso)#
L1 正则化是通过向模型的损失函数中添加 L1 范数惩罚项来实现的。L1 范数是指模型权重参数的绝对值之和。L1 正则化的目标是最小化损失函数和 L1 范数惩罚项的和。它的效果是将一些权重参数压缩为零,从而实现特征选择和稀疏性。
梯度增加了一个常数项;通过在每次迭代中从未提供信息的特征中减去一小部分,迫使它们的权重为零,从而最终使它们为零。鼓励模型具有较小的系数,并可能导致稀疏性。
L2 Regularization (Ridge)#
L2 正则化是通过向模型的损失函数中添加 L2 范数惩罚项来实现的。L2 范数是指模型权重参数的平方和的平方根。L2 正则化的目标是最小化损失函数和 L2 范数惩罚项的和。它的效果是减小权重参数的值,使模型的权重参数更加平滑和稳定。
效果类似于 weight decay。鼓励模型具有较小的系数,但不会导致稀疏性。
L1 | L2 | |
---|---|---|
作用 | 能够产生更加稀疏的模型 | 有助于计算病态的问题 |
概率先验 | 参数服从拉普拉斯分布 | 参数服从高斯先验分布 |
Dropout#
Dropout 正则化是通过在训练过程中随机将一部分神经元的输出置为零来实现的。具体而言,每个神经元有一定的概率被丢弃,使得模型无法依赖于特定的神经元,从而增强模型的泛化能力。
(1)取平均的作用: 先回到标准的模型即没有 dropout,我们用相同的训练数据去训练 5 个不同的神经网络,一般会得到 5 个不同的结果,此时我们可以采用 “5 个结果取均值” 或者 “多数取胜的投票策略” 去决定最终结果。例如 3 个网络判断结果为数字 9, 那么很有可能真正的结果就是数字 9,其它两个网络给出了错误结果。这种 “综合起来取平均” 的策略通常可以有效防止过拟合问题。因为不同的网络可能产生不同的过拟合,取平均则有可能让一些 “相反的” 拟合互相抵消。dropout 掉不同的隐藏神经元就类似在训练不同的网络,随机删掉一半隐藏神经元导致网络结构已经不同,整个 dropout 过程就相当于对很多个不同的神经网络取平均。而不同的网络产生不同的过拟合,一些互为 “反向” 的拟合相互抵消就可以达到整体上减少过拟合。
(2)减少神经元之间复杂的共适应关系: 因为 dropout 程序导致两个神经元不一定每次都在一个 dropout 网络中出现。这样权值的更新不再依赖于有固定关系的隐含节点的共同作用,阻止了某些特征仅仅在其它特定特征下才有效果的情况。迫使网络去学习更加鲁棒的特征,这些特征在其它的神经元的随机子集中也存在。换句话说假如我们的神经网络是在做出某种预测,它不应该对一些特定的线索片段太过敏感,即使丢失特定的线索,它也应该可以从众多其它线索中学习一些共同的特征。从这个角度看 dropout 就有点像 L1,L2 正则,减少权重使得网络对丢失特定神经元连接的鲁棒性提高。
(3)Dropout 类似于性别在生物进化中的角色:物种为了生存往往会倾向于适应这种环境,环境突变则会导致物种难以做出及时反应,性别的出现可以繁衍出适应新环境的变种,有效的阻止过拟合,即避免环境改变时物种可能面临的灭绝。
- Vanilla Dropout:测试时输出乘以
- Inverted Dropout:(主流方法)训练时数据乘以
import numpy as np
class Dropout:
def __init__(self, dropout_rate):
self.dropout_rate = dropout_rate
self.mask = None
def forward(self, x, is_train):
if is_train:
self.mask = np.random.binomial(1, 1 - self.dropout_rate, size=x.shape) / (1 - self.dropout_rate)
out = x * self.mask
else:
out = x
return out
def backward(self, dout):
dx = dout * self.mask
return dx
Early Stopping#
Early Stopping 是一种基于验证集性能的策略,用于在训练过程中提前停止模型的训练。通过监控模型在验证集上的性能指标(如损失函数或准确率),如果模型的性能在一定的训练轮数内没有改善,就提前终止训练,以避免过拟合。
Batch Normalization#
Batch Normalization 是通过对每个批次(batch)的输入进行归一化,使得网络中间层的输入保持较小的分布变化。它通过减小输入数据的内部协方差移动和缩放来加速网络的训练过程,并且有助于防止梯度消失或梯度爆炸。
适用于 batch 较大、序列长度固定的情形,如 CNN。
$\gamma$ 和 $\beta$ 分别是缩放因子和偏移量。因为减去均值除以方差未必是最好的分布。所以要加入两个可学习变量来完善数据分布以达到比较好的效果。
训练的时候使用滑动平局收集均值和方差,测试的时候直接使用。
import numpy as np
class BatchNormalization:
def __init__(self, epsilon=1e-5, momentum=0.9):
self.epsilon = epsilon
self.momentum = momentum
self.running_mean = None
self.running_var = None
self.gamma = None
self.beta = None
def forward(self, X, training=True):
N, D = X.shape
if self.running_mean is None:
self.running_mean = np.zeros(D)
if self.running_var is None:
self.running_var = np.zeros(D)
if training:
sample_mean = np.mean(X, axis=0)
sample_var = np.var(X, axis=0)
X_normalized = (X - sample_mean) / np.sqrt(sample_var + self.epsilon)
self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * sample_mean
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * sample_var
self.gamma = np.ones(D)
self.beta = np.zeros(D)
else:
X_normalized = (X - self.running_mean) / np.sqrt(self.running_var + self.epsilon)
out = self.gamma * X_normalized + self.beta
return out
BN 解决了什么问题?
神经网络在做非线性变换前的激活输入值随着网络深度加深,其分布逐渐发生偏移(內部协变量偏移)。之所以训练收敛慢,一般是整体分布逐渐往非线性函数的区间两端靠近,导致反向传播底层网络训练梯度消失。BN 就是通过一定的正则化手段,每层神经网络任意神经元输入值的分布强行拉回到均值为 0 方差为 1 的标准正态分布。
Layer Normalization#
适用于变长序列,如 RNN、Transformer。
如果 BN 应用到 NLP 任务,相当于是在对默认了在同一个位置的单词对应的是同一种特征。
Others#
- Elastic Net 正则化:结合了 L1 正则化和 L2 正则化的一种方法。它在损失函数中同时引入 L1 范数和 L2 范数的惩罚项,通过调节两者的权重超参数来平衡正则化的影响。Elastic Net 正则化能够同时实现特征选择和权重缩小的效果。
- 数据增强是一种通过对原始训练数据进行一系列随机变换或扩充来增加数据样本量的技术。这些变换可以包括随机旋转、平移、缩放、翻转等操作。数据增强可以帮助模型更好地泛化,并减轻过拟合的问题。
- 参数共享是一种在神经网络中共享部分参数的技术。在某些具有相似结构或功能的层之间共享参数可以减少模型的参数量,从而降低模型的复杂性和过拟合的风险。参数共享通常用于卷积神经网络(Convolutional Neural Networks)中。