Tensorflow模型训练正常但推理输出NaN

发布于 2023-07-13  379 次阅读


问题

使用tf.keras.applications.ResNet101V2作为backbone的多输出多分类模型训练正常,推理时输出nan。使用tf混合精度进行训练,期间有出现过loss nan,降低学习率后解决。在差不多收敛后进行predict测试,发现模型输出nan。

调试

检查模型权重

输出模型weight,搜索nan,发现某些层的权重存在nan的情况,具体来说大部分集中在最后一个BatchNormalization层。

def check_and_initialize_nan_weights(model):
    for layer in model.layers:
        for weight in layer.weights:
            try:
                nan_indices = tf.debugging.check_numerics(weight, "NaN detected in weight")
            except Exception as e:
                # assert "Checking b : Tensor had NaN values" in e.message
                print("NaN detected in weight:", weight.name)
                # layer.kernel_initializer(shape=np.asarray(layer.kernel.shape)), \
                #                    layer.bias_initializer(shape=np.asarray(layer.bias.shape))])
                # 对 NaN 权重进行初始化
                nan_mask = tf.math.is_nan(weight)
                weight.assign(tf.where(nan_mask, tf.ones_like(weight), weight))

检查中间层输出

有时候不仅仅是权重问题,一些计算逻辑也可能导致nan,于是查看了模型中间层的输出,找到到底是哪一层计算逻辑后出现了nan,结果是最后一个resnet block中的命名为conv5_block3_2_bn的BatchNormalization层的moving_variance过大,以及命名为post_bn的全局池化层之前的最后一个BatchNormalization层的moving_meanmoving_variance都出现了nan。

res = model.get_layer('resnet101v2')
bn = res.get_layer('conv5_block3_1_bn')
post_bn = res.get_layer('post_bn')
bn_2 = res.get_layer('conv5_block3_2_bn')
bn_2.set_weights([bn_2.weights[0],bn_2.weights[1],bn_2.moving_mean_initializer(shape=np.asarray(bn_2.moving_mean.shape)),bn_2.moving_variance_initializer(shape=np.asarray(bn_2.moving_variance.shape))])
post_bn.set_weights([post_bn.weights[0],post_bn.weights[1],post_bn.moving_mean_initializer(shape=np.asarray(post_bn.moving_mean.shape)),post_bn.moving_variance_initializer(shape=np.asarray(post_bn.moving_variance.shape))])
input = tf.keras.Input(shape=(224, 224, 3), name="input")
sub = keras.Model(inputs=res.input, outputs=res.get_layer('conv5_block3_out').output)
output = sub(input)
submodel = keras.Model(inputs=input, outputs=output)
submodel.summary()
print(submodel(image))

分析

BatchNormalization

此次模型输出nan是由于BatchNormalization层参数异常造成的,至于为什么训练时候正常是由于BatchNormalization层在训练与推理时不同的行为造成的。

BatchNormalization应用一种变换,该变换将平均输出(均值)保持在接近0的水平,并将输出标准偏差保持在接近1的水平。批量归一化在训练和推理期间的工作方式不同。

在训练期间(即使用fit()或使用参数调用层/模型时training=True),层使用当前批次输入的平均值和标准差标准化其输出。也就是说,对于每个被归一化的通道,该层返回 gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta,其中:

  • epsilon是小常量(可配置为构造函数参数的一部分)
  • gamma是一个学习的缩放因子(初始化为 1),可以通过传递scale=False给构造函数来禁用它。
  • beta是一个学习的偏移因子(初始化为 0),可以通过传递center=False给构造函数来禁用它。

在推理过程中(即,当使用evaluate()predict()或使用参数调用层/模型时training=False(这是默认值),该层使用在训练期间看到的批次的平均值和标准差的移动平均值来规范化其输出。即说,它回来了 gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta

self.moving_meanself.moving_var是不可训练的变量,每次在训练模式下调用层时都会更新,如下所示:

  • moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
  • moving_var = moving_var * momentum + var(batch) * (1 - momentum)

因此,该层仅 在接受与推理数据具有相似统计数据的数据训练后,才会在推理过程中对其输入进行归一化。

混合精度训练

正常情况下模型的权重不应会出现nan的情况,但是在使用混合精度进行训练时,由于精度降低,会出现数值不稳定的情况,导致loss出现nan或者inf,进而导致梯度为nan或者inf,梯度回传更新权重,模型权重就会出现异常。

混合精度是指训练时在模型中同时使用 16 位和 32 位浮点类型,从而加快运行速度,减少内存使用的一种训练方法。通过让模型的某些部分保持使用 32 位类型(通常是输入层与预测头)以保持数值稳定性,可以缩短模型的单步用时,而在评估指标(如准确率)方面仍可以获得同等的训练效果。

float16 数据类型的动态范围比 float32 窄。这意味着大于 65504 的数值会因溢出而变为无穷大,小于 6.0×10−8 的数值则会因下溢而变成零。float32 和 bfloat16 的动态范围要大得多,因此一般不会出现下溢或溢出的问题。

参考

https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

https://github.com/keras-team/keras/issues/17204

https://github.com/keras-team/keras/issues/7157

https://stackoverflow.com/questions/58657003/how-can-i-use-tf-keras-model-summary-to-see-the-layers-of-a-child-model-which-in

https://www.tensorflow.org/guide/mixed_precision


面向ACG编程