Distilling the Knownledge in a Neural Network

编程入门 行业动态 更新时间:2024-10-25 14:30:01

<a href=https://www.elefans.com/category/jswz/34/1368851.html style=Distilling the Knownledge in a Neural Network"/>

Distilling the Knownledge in a Neural Network

论文要点

需要解决的问题

作者首先指出一种提升机器学习算法能力的万金油方法:在同一数据集上训练不同的模型,然后将这些模型的预测结果取平均值。这种方法简单粗暴,但是成本很高(using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users)。Hinton以及团队提出了一种压缩算法,它能将一个集成模型中知识压缩到便于部署的模型上。

解决问题的方法

本篇论文的精髓思想可以用两个精妙的比喻来涵盖。
1.昆虫形态
许多种昆虫都有幼年和成年两种完全不同的形态,幼年的形态利于从环境汲取能量和养分,成年形态利于迁移和生产。而对于机器学习而言,我们在训练过程中可能需要规模庞大的模型才能从大量的数据中学习到有用的信息结构。但是在部署到生产环境中,如果使用同样笨重的模型,会大大影响用户的体验。基于这个痛点,我们可以设法将训练中笨重模型中的知识迁移到部署环境中的小模型中,这样能在保证模型准确率的同时保证用户的体验。
Once the cumbersome model has been trained, we can then use a different kind of training, which we call “distillation” to transfer the knowledge from the cumbersome model to a small model that is more suitable for deployment. 作者将知识从笨重模型迁移到轻量模型的过程称作蒸馏。

上述中的想法很美好,但是模型和人类之间学习的方式天差地别。我们人类传递知识可以通过书籍、视频或者是口口相传等形式。而模型中的知识太过抽象,在传统的fine-tune方法中,预训练模型和真正训练的模型相同,参数继承可以被看作是知识传递。但是如果两个模型之间的结构不同,复杂模型中的知识又该如何传递到轻量模型中去呢?

作者用了如下一段话描述。

A more abstract view of the knowledge, that frees it from any particular instantiation, is that it is a learned mapping from input vectors to output vectors.

这段话的大致意思是说,知识不是指模型具体学习过哪些样本,而是指的是输入张量到输出张量的映射。但是模型本身就代表一种映射,这个解释无所直接指导我们设计算法。
对于一个分类的判别式模型,它的输出是每个类别的概率。通常情况下,模型会对自己认为是正确答案的那一类给出一个很高的概率,而其它错误类别对应的概率会很小(使用交叉熵损失的必然性)。作者认为,这些不正确类别,它们之间的相对概率代表了模型想要泛化的东西(The relative probabilities of incorrect answers tell us a lot about how the cumbersome model tends to generalize).例如,一张宝马的图片,只有很小的概率会被认成一个垃圾车,但是这个概率绝对要比错认成萝卜的概率大得多。
模型的输出概率中,它认为垃圾车要比萝卜更像宝马,除了正确类别信息,它还告诉了我们类别之间的相互关系。
基于此,作者提出了一种可以直接指导我们设计算法的思想。An obvious way to transfer the generalization ability of the cumbersome model to a small model is to use the class probabilities produced by the cumbersome model as “soft targets” for training the model.将训练好的复杂模型输出的概率分布当作我们训练的目标,这比单单使用one-hot向量作为标签学习到的信息要多得多。

这个方法从目前来看,适用于自监督,半监督和无监督的任务。但从实验效果来看,同时使用复杂模型的soft labels和one-hot向量对应的hard labels效果最好。

2.知识蒸馏
这里是本篇论文第二个精妙的比喻。在训练图像分类时,往往会将最后一层全连接层的输出送入到softmax函数得到概率值。这样会使得正确类别的概率要比错误类别的概率数量级大得多,这会导致用soft label做交叉熵时,错误类别提供的信息量可以忽略不计。
为了使得正确类别的概率与错误类别概率的数量级差距缩小,可以采取如下的方法:
q i = e x p ( z i / T ) ∑ j e x p ( z j / T ) q_i=\frac {exp(z_i/T)} {\sum_j exp(z_j/T)} qi​=∑j​exp(zj​/T)exp(zi​/T)​

Neural networks typically produce class probabilities by using a “softmax” output layer that converts the logit computed for each class into a probability q i q_i qi​, by comparing z i z_i zi​ with the other logits.

这里的T被称为temperature,指的是知识蒸馏的参数,不考虑知识迁移时T设为1。
初中化学中我们曾学习过蒸馏的实验,慢慢提高加热器的温度可以制得蒸馏水。作者此处把模型中的知识看作是待蒸馏的对象,随着T的增高,正确类别和错误类别的概率值之间的数量级随之变小。当T达到某一值时,使用复杂模型当作soft label的效果会达到最佳。

算法的PyTorch简易实现

数据集

使用的是PyTorch自带的CIFAR10数据集。

import torch.utils.data as Data
import torchvision
import torchvision.transforms as transformsdef get_loader():train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])train_set = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=train_transform)test_set = torchvision.datasets.CIFAR10(root='./root',train=False,download=True,transform=test_transform)train_loader = Data.DataLoader(train_set, batch_size=128, shuffle=True,)test_loader = Data.DataLoader(test_set, batch_size=100, shuffle=False)return train_loader, test_loader

模型

模型部分需要准备一个较为复杂的模型用来预训练和一个轻量模型来做蒸馏训练。

import torch.nn.functional as Fclass Block(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.block = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, padding=1, kernel_size=(3, 3)),nn.BatchNorm2d(out_channels),nn.ReLU(),nn.Conv2d(in_channels=out_channels, out_channels=out_channels, padding=1, kernel_size=(3, 3)),nn.BatchNorm2d(out_channels),nn.ReLU(),nn.MaxPool2d(kernel_size=(3, 3), padding=1, stride=2))def forward(self, t):return self.block(t)class Student(nn.Module):def __init__(self):super().__init__()self.block1 = Block(3, 64)self.block2 = Block(64, 128)self.block3 = Block(128, 256)self.fc = nn.Linear(4096, 10)def forward(self, t):# t:[b, 3, 32, 32]t = self.block1(t)  # [b, 64, 16, 16]t = self.block2(t)  # [b, 128, 8, 8]t = self.block3(t)  # [b, 256, 4, 4]t = t.flatten(start_dim=1) # [b, 4096]return self.fc(t)class Teacher(nn.Module):def __init__(self):super().__init__()vgg = vgg16(pretrained=False)self.cnn = list(vgg.children())[0]self.fc1 = nn.Linear(512, 10)def forward(self, t):t = self.cnn(t)t = t.flatten(start_dim=1)t = self.fc1(t)t = F.relu(t)return t

复杂模型指的是Teacher,其中包含了vgg16的卷积部分以及额外添加的用于分类的全连接层, 轻量模型指的是Student,内部结构是三个卷积单元和一个全连接层

损失函数

损失函数是该算法的核心

import torch
import torch.nn.Functional as Fdef soft_loss(output, target, label, temperature, alpha):"""output:是轻量模型预测的概率分布target:是复杂模型输出的概率分布label:真实的ont-hotlabel"""q = F.softmax(target/temperature, dim=1)p = F.log_softmax(output/temperature, dim=1)soft_loss = -torch.mean(torch.sum(q * p, dim=1))hard_loss = F.cross_entropy(output, label)loss = alpha * soft_loss + (1 - alpha) * hard_lossreturn loss

训练

和常规训练图片分类任务一样,不同在于流程。首先,利用数据集对复杂模型进行训练,loss使用的是常规的交叉熵损失。蒸馏过程中,先将样本输入复杂网络进行预测,再利用我们自己设计的损失函数计算loss。

更多推荐

Distilling the Knownledge in a Neural Network

本文发布于:2023-07-28 18:48:39,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1278883.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:Knownledge   Distilling   Network   Neural

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!