蒸馏网络

参考论文: Distilling the Knowledge in Neural Network
Geoffrey Hinton, Oriol Vinyals, Jeff Dean
preprint arXiv:1503.02531, 2015
NIPS 2014 Deep Learning Workshop

概况

蒸馏网络的主要想法,其最本质的思想是来源于昆虫记里面的故事:
“蝴蝶以毛毛虫的形式吃树叶积攒能量逐渐成长,最后变换成蝴蝶这一终极形态来完成繁殖。”

简单来说,就是利用一个教师模型(大参数模型)去训练一个学生模型(小参数模型)。尽管学生模型可能最终依然达不到教师模型的准确性,但是被老师教过的学生模型会比自己单独训练的学生模型更加强大。

设计思想

在用神经网络训练大规模数据集时,为了处理复杂的数据分布。主要由两种做法

  • 一种做法是建立复杂的神经网络模型,例如含有上百层的残差网络,这种复杂的网络往往含有多达几百万个参数。
  • 另一种做法是混合多种模型,将几个大规模的神经网络在同一个数据集上训练好,然后综合多个模型,得到最终的分类结果。

以上两种方案,重新训练成本高,并且模型庞大难以部署

于是就有了蒸馏网络,最基本的想法就是将大模型学习出来的知识作为先验,将先验知识传递到小规模的神经网络中,之后实际应用中部署小规模的神经网络。这样做有三点依据:

  • 大规模神经网络得到的类别预测包含了数据结构间的相似性
  • 有了先验的小规模神经网络只需要很少的新场景数据就能够收敛;
  • Softmax函数随着温度变量(temperature)的升高分布更均匀

船新的softmax定义

假设有一个数组$$V$$,$$V_i$$代表$$V$$中的第$$i$$个元素,该元素的$$softmax$$值为

$$P_{i}=\cfrac{e^{V_i}}{\sum_{j} e^{V_j}}$$

之后,一般会利用交叉熵损失函数来优化参数,其形式为

$$\arg \min _{\theta_{F}}-\frac{1}{|\mathcal{X}|} \sum_{X \in \mathcal{X}} \sum_{i \in 0 . . N} Y_{i}(X) \log P_{i}(X)$$

而现在,重新定义softmax,变为

$$P_{i}^T = \cfrac{e^{V_i/T}}{\sum_{j} e^{V_j / T}}$$

其中的T为温度参数,这是一个超参数。不难发现,如果T设置为1,那就是普通的softmax函数。而T越高,意味着随着T的上升,给出的概率值的方差会越小(作者解释是约软,我理解这个“软”的意思,就是方差越小),目标的分布会更加均匀

def softmaxT(x,t = 1):
    """Compute the softmax in a numerically stable way."""
    x = x - np.max(x)
    exp_x = np.exp(x/t)
    softmax_x = exp_x / np.sum(exp_x)
    return softmax_x

softmaxT([1, 2, 3],1)
> array([0.09003057, 0.24472847, 0.66524096])
softmaxT([1, 2, 3],1.5)
> array([0.1483371 , 0.28892122, 0.56274169])
softmaxT([1, 2, 3],2)
> array([0.18632372, 0.30719589, 0.50648039])
softmaxT([1, 2, 3],3)
> array([0.23023722, 0.32132192, 0.44844086])

对此,作者提出了训练的过程:

  1. 首先用较大的T值来训练教师模型,复杂的神经网络能够产生更均匀分布的软目标。
  2. 之后小规模的神经网络用相同的T值来学习由大规模神经产生的软目标,接近这个软目标从而学习到数据的结构分布特征。
    于是,学生网络的交叉熵损失函数的形式就变为

  3. 最后在实际应用中,将T值恢复到1,让类别概率偏向正确类别。

代码实现

虽说我按照作者的思路,与论文中提出的模型实现了一下这个蒸馏网络。但是和作者的结果非常不相似。。。关键作者也没有细说它的网络是啥样子的。或许他没用到卷积层?

softmaxT

本来我是打算手撸一个激活函数的,在论坛上查到了1.8的代码。然而,发现很多接口tf2.0改起来很麻烦。后来想到可以去改源码。于是就有了如下代码。如有不得当之处,还请指正

def softmaxT(x, t, axis=-1):
    x = x / t
    ndim = tf.keras.backend.ndim(x)
    if ndim == 2:
        return tf.nn.softmax(x)
    elif ndim > 2:
        e = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))
        s = tf.reduce_sum(e, axis=axis, keepdims=True)
        return e / s
    else:
        raise ValueError('Cannot apply softmax to a tensor that is 1D. '
                         'Received input: %s' % (x,))

学生网络和教师网络

#####定义老师模型——包含三层卷积层的CNN模型
def teacher_model(temperature=1):
    input_ = tf.keras.layers.Input(shape=(28,28,1))
    x = tf.keras.layers.Conv2D(32,(3,3),padding = "same",activation="relu")(input_)
    x = tf.keras.layers.MaxPool2D((2,2))(x)
    x = tf.keras.layers.Conv2D(64,(3,3),padding= "same",activation="relu")(x)
    x = tf.keras.layers.MaxPool2D((2,2))(x)
    x = tf.keras.layers.Conv2D(64,(3,3),padding= "same",activation="relu")(x)
    x = tf.keras.layers.MaxPool2D((2,2))(x)
    x = tf.keras.layers.Flatten()(x)
    # he large net at a temperature of 20, it achieved 74 test errors.
    x = tf.keras.layers.Dense(256,activation="relu")(x)
    x = tf.keras.layers.Dropout(0.2)(x)   
    x = tf.keras.layers.Dense(10)(x)
    out = softmaxT(x,t=temperature)
    model = tf.keras.Model(inputs=input_,outputs=out)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                 optimizer=tf.keras.optimizers.Adam(lr=0.01),
                 metrics=["accuracy"])
    model.summary()
    return model

###定义学生模型——— 一层含512个神经元的全连接层
def student_model(temperature=1):
    '''
    When the distilled net had 300 or more units in each of its two hidden layers,
    all temperatures above8 gave fairly similar results. 
    But when this was radically reduced to 30 units per layer, 
    temperaturesin the range 2.5 to 4 worked 
    significantly better than higher or lower temperatures

    :return: 
    '''
    input_ = tf.keras.layers.Input(shape=(28,28,1))
    x = tf.keras.layers.Flatten()(input_)
    x = tf.keras.layers.Dense(512,activation="sigmoid")(x)
    out = tf.keras.layers.Dense(10)(x)
    out = softmaxT(out,t=temperature)
    model = tf.keras.Model(inputs=input_,outputs=out)
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                 optimizer=tf.keras.optimizers.Adam(lr=0.01),
                 metrics=["accuracy"])
    model.summary()
    return model

t_model  = teacher_model(20)
t_model.fit(data_train,label_train,batch_size=64,epochs=2,validation_data=(data_test,label_test))

s_model = student_model(1)
s_model.fit(data_train,label_train,batch_size=64,epochs=10,validation_data=(data_test,label_test))

令人匪夷所思的是,虽然我温度按照原文的说法,大网络用了20°,但是正确率居然出奇得高,只用了两步就达到了96%。而原文就达到了74%

Epoch 2/2
60000/60000 [==============================] - 31s 522us/sample - loss: 0.1486 - accuracy: 0.9590 - val_loss: 0.1321 - val_accuracy: 0.9603

如果直接训练学生模型,由于层数很浅,只有82%的准确率

Epoch 10/10
60000/60000 [==============================] - 3s 45us/sample - loss: 0.4890 - accuracy: 0.8512 - val_loss: 0.5782 - val_accuracy: 0.8246

利用教师的输出训练学生模型

def teach_student(teacher_out, student_model,data_train,data_test,label_test):
    t_out = teacher_out
    s_model = student_model
    for l in s_model.layers:
        l.trainable = True     

    label_test = tf.keras.utils.to_categorical(label_test)

    model = tf.keras.Model(s_model.input,s_model.output)
    model.compile(loss="categorical_crossentropy",
                 optimizer="adam",metrics=["accuracy"])
    model.fit(data_train,t_out,batch_size= 64,epochs = 5,validation_data=(data_test,label_test))

    s_predict = np.argmax(model.predict(data_test),axis=1)
    s_label =  np.argmax(label_test,axis=1)
    print(accuracy_score(s_predict,s_label))

t_out = t_model.predict(data_train)
teach_student(t_out,s_model,data_train,data_test,label_test)

之后利用教师模型的输出来训练学生模型,速度还是很快的,就算我用CPU也只花了半分钟不到。在10次迭代后达到了91%的正确率,并且看上去还可以继续增长。

Epoch 10/10
60000/60000 [==============================] - 3s 50us/sample - loss: 0.4303 - accuracy: 0.9000 - val_loss: 0.3050 - val_accuracy: 0.9118

防御性蒸馏

参考论文 Distillation as a Defense to Adversarial Perturbations against Deep Neural Networks

关于本文提出的防御性蒸馏方法,与一般的蒸馏不同的是学生网络与教师网络所使用的结构是相同的。作者发现对教师网络使用JSMA算法进行攻击产生了95%的成功率,但经过蒸馏后,成功率降到了0.5%。

但是蒸馏网络只能用来防御JSMA算法。在之前的论文阅读中,不难发现,其实JSMA是基于前向导数的。而现在softmax被修改了。具体推导如下

$$\begin{aligned}
\left.\frac{\partial F_{i}(X)}{\partial X_{j}}\right|_{T} &=\frac{\partial}{\partial X_{j}}\left(\frac{e^{z_{i} / T}}{\sum_{l=0}^{N-1} e^{z_{l} / T}}\right) \\
&=\frac{1}{g^{2}(X)}\left(\frac{\partial e^{z_{i}(X) / T}}{\partial X_{j}} g(X)-e^{z_{i}(X) / T} \frac{\partial g(X)}{\partial X_{j}}\right) \\
&=\frac{1}{g^{2}(X)} \frac{e^{z_{i} / T}}{T}\left(\sum_{l=0}^{N-1} \frac{\partial z_{i}}{\partial X_{j}} e^{z_{l} / T}-\sum_{l=0}^{N-1} \frac{\partial z_{l}}{\partial X_{j}} e^{z_{l} / T}\right) \\
&=\frac{1}{T} \frac{e^{z_{i} / T}}{g^{2}(X)}\left(\sum_{l=0}^{N-1}\left(\frac{\partial z_{i}}{\partial X_{j}}-\frac{\partial z_{l}}{\partial X_{j}}\right) e^{z_{l} / T}\right)
\end{aligned}$$

看上去非常复杂,但是其实只要把导数除法的公式套里面就能够推出来了。

很显然,这个推导告诉了我们两点内容

  • 偏导数的大小(或者说雅克比矩阵中的每一个元素都)与温度T成反比。PS:这将会导致需要更多的迭代次数才能达成攻击
  • 对数在被指数化之前,先除以了温度T。PS:这将导致模型对于输入的成分并不是这么敏感。

作者后来又提及,T选取100时,防御效果最佳。

打破防御性蒸馏网络

参考论文: Defensive Distillation is Not Robust to Adversarial Examples

这篇文章非常简短,只有3页。但是和之后介绍的C&W攻击的作者是同一位。

作者介绍了一种修改过的JSMA_Z算法,可以打破蒸馏网络的防御。

修改的点如下:

首先是对输入的处理,在利用输入之前,先将输入人为地除以T后套上softmax函数,实际上就还原了学生网络的输入。

$$\hat{F}(\theta, x)=\operatorname{softmax}(Z(\theta, x) / T)$$

之后是对像素选取的处理,现在不再是选取一对符合原方程的像素点,而是找单个的像素点。其应该满足

$$p^{*}=\underset{p}{\arg \max } 2 \frac{\partial \hat{F}(x)_{t}}{\partial x_{p}}-\sum_{j=0}^{9} \frac{\partial \hat{F}(x)_{j}}{\partial x_{p}}$$

这里之所以写9是因为用了minist数据集。