Unified Deep Supervised Domain Adaptation and Generalization

编程入门 行业动态 更新时间:2024-10-25 08:18:42

Unified Deep <a href=https://www.elefans.com/category/jswz/34/1715543.html style=Supervised Domain Adaptation and Generalization"/>

Unified Deep Supervised Domain Adaptation and Generalization

论文概述

问题研究背景:supervised domain adaptation(SDA),源域有大量带标签的数据,目标域仅有少量可使用的数据

问题的难点:目标域数据不足导致概率分布在语义上很难对齐和区分。对齐指的是源域图片类别之间的关系与目标域图片类别之间的关系尽可能的相似;区分指的是同一个domain中,不同类别的特征要尽可能不同。

方法的优点:适应的速度快,仅需要少量的数据就可以获得很不错的效果。易于拓展成DG方法。

具体的方法实现
一般来说,一个DA模型对应的 f u n c t i o n function function可以被看作是两个函数的组合 f = h o g . f=h\ o\ g. f=h o g. 其中 g : χ → Z , χ g:\chi \rightarrow Z,\chi g:χ→Z,χ代表输入的特征空间,Z代表embedding space。 h : z → y 。 h:z\rightarrow y。 h:z→y。h代表利用embedding space中的特征进行预测的函数。对于源域的数据有 f s = h s o g s f_s=h_s\ o\ g_s fs​=hs​ o gs​,对于目标域的数据有 f t = h t o g t f_t=h_t\ o\ g_t ft​=ht​ o gt​。
为了对齐源域和目标域之间数据的分布(即上述中的g),常使用下面形式的loss:
L C A ( g ) = d ( p ( g ( X s ) ) , p ( X t ) ) L_{CA}(g)=d(p(g(X^s)),p(X^t)) LCA​(g)=d(p(g(Xs)),p(Xt))
这个loss的作用就是让源域和目标域中的数据经过映射以后无法被分辨。用原文中的话来说就是 I n t h e e m b e d d i n g s p a c e Z , f e a t u r e s a r e a s s u m e d t o b e d o m a i n i n v a r i n t In\ the\ embedding\ space\ Z,\ features\ are\ assumed\ to\ be\ domain\ invarint In the embedding space Z, features are assumed to be domain invarint
上面的loss对于unsupervised domain adaptation来说很适合,但是存在一个很大的问题:没办法保证不同域之间的语义是对齐的。SDA相较于UDA就可以利用label信息来对齐语义。loss被改写成如下的形式:
L S A ( g ) = ∑ a = 1 C d ( p ( g ( X a s ) , p ( g ( X a t ) ) ) ) L_{SA}(g)=\displaystyle\sum_{a=1}^Cd(p(g(X^s_a),p(g(X^t_a)))) LSA​(g)=a=1∑C​d(p(g(Xas​),p(g(Xat​))))
上述loss被称为semantic alignment loss,d是一个度量距离的函数,具体来说,它用来度量不同域中同一类别样本的特征在被映射到embedding space之后的距离,我们希望这个距离越小越好。

光有上述loss还不行,因为模型学习的方向可能使得所有的类别分布趋同。为了使得不同域的不同类别之间距离尽可能变大,需要加上下面的separation loss
L S ( g ) = ∑ a , b ∣ a ≠ b k ( p ( g ( X a s ) ) , p ( g ( X b t ) ) ) L_S(g)=\displaystyle\sum_{a,b|a\ne b }k(p(g(X^s_a)),p(g(X^t_b))) LS​(g)=a,b∣a​=b∑​k(p(g(Xas​)),p(g(Xbt​)))
k表示相似性函数,当源域中的a类与目标域中b类靠的太近时会施加惩罚。
最后是个用来分类的loss L C L_C LC​,多任务分类一般使用的是交叉熵函数。最后loss的表达形式如下:
L C C S A ( f ) = L C ( h o g ) + L S A ( g ) + L S ( g ) L_{CCSA}(f)=L_C(h\ o\ g) + L_{SA}(g)+L_S(g) LCCSA​(f)=LC​(h o g)+LSA​(g)+LS​(g)
由于目标域中的数据很少,文章中提出逐点计算loss。具体来说,作者将目标域中每个样本与源域中的样本进行配对,每个样本对应的embedding feature之间进行loss的计算。

代码实现

模型部分

class NetWork(nn.Module):def __init__(self):super().__init__()self.cnn = nn.Sequential(nn.Conv2d(1, 32, kernel_size=(3, 3)),nn.ReLU(),nn.Conv2d(32, 32, kernel_size=(3, 3)),nn.ReLU(),nn.MaxPool2d((2, 2)),nn.Dropout(0.25),nn.Flatten(),nn.Linear(1152, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU())self.classifier = nn.Sequential(nn.Dropout(0.5),nn.Linear(84, 10))def forward(self, x):feature = self.cnn(x)prediction = self.classifier(feature)return prediction, feature

模型部分,其中包括一个卷积神经网络用作特征提取器,它对应的就是function g。function h对应的是全连接层用来实现分类功能。模型会返回提取的特征feature以及分类器输出的概率分布。

loss

DA方法loss占据着核心地位。

def csa_loss(x, y, class_eq):margin = 1dist = F.pairwise_distance(x, y, 2)loss = class_eq * dist.pow(2)loss += (1 - class_eq) * (margin - dist).clamp(min=0).pow(2)return loss.mean()

x表示源域样本的embedding feature,y是目标域样本中的embedding feature,class_eq代表源域样本和目标域样本是否是同一种类。首先,计算两个特征图之间各像素点的二范数平方和,这是用于语义对齐的损失。后面,计算separation loss。semantic alighment loss和separation loss只会有一个存在。

训练过程

def train(net, loader):net.train()for i, (src_img, src_label, tar_img, tar_label) in enumerate(loader):src_img = src_img.to(device)src_label = src_label.to(device).long()tar_img = tar_img.to(device)tar_label = tar_label.to(device).long()src_pred, src_feature = net(src_img)_, tar_feature = net(tar_img)ce = entropy_loss(src_pred, src_label)csa = csa_loss(src_feature, tar_feature, (src_label == tar_label).float())loss = (1 - alpha) * ce + alpha * csaoptimizer.zero_grad()loss.backward()optimizer.step()if i % 100 == 0:print("loss : %4f" % (loss.item()))for i, (tar_img, tar_label, src_img, src_label) in enumerate(loader):src_img = src_img.to(device)src_label = src_label.to(device).long()tar_img = tar_img.to(device)tar_label = tar_label.to(device)src_pred, src_feature = net(src_img)_, tar_feature = net(tar_img)ce = entropy_loss(src_pred, src_label)csa = csa_loss(src_feature, tar_feature,(src_label == tar_label).float())loss = (1 - alpha) * ce + alpha * csaoptimizer.zero_grad()loss.backward()optimizer.step()

更多推荐

Unified Deep Supervised Domain Adaptation and Generalization

本文发布于:2023-07-28 18:48:40,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1278886.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:Supervised   Deep   Unified   Generalization   Adaptation

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!