ICCV 2021"/>
ICCV 2021
点击下方卡片,关注“CVer”公众号
AI/CV重磅干货,第一时间送达
这篇文章基于optimal transport的理论,从几何角度出发,把生成对抗模型解释为真实数据流形上 metric 和 measure 之间的互动。作者认为这是数学理论和机器学习的一次美妙结合,希望能激发大家对生成对抗模型的更多思考。感兴趣的欢迎点击论文链接(单位:Microsoft,University of Delaware):
Manifold Matching via Deep Metric Learning for Generative Modeling
.10777
传统的对抗生成网络( GAN)从统计学角度出发,通过匹配 real/fake data distribution 的 mean, moments, 等统计信息来构造生成模型。在这个框架下,判别器在训练过程中的行为可以解读为:在学习 real/fake data distribution 之间的距离。在训练任务结束后,我们一般只保存生成器的参数;对于判别器我们很难从中获取更多关于 real data set的有用信息。
在这篇文章里我们提出了一个新的生成模型:MvM (Manifold-matching via Metric-learning)。类似于传统的GAN模型,MvM由两个神经网络组成:distribution generator和 metric generator。在整个训练过程中,这两个神经网络交替学习直至收敛。不同于GAN的是,MvM完全从几何角度出发:假设 real data set 是嵌入在高维欧几里德空间中的一个黎曼流形 (manifold),这样的话,
1. distribution generator 可以看成是在生成一个邻近 real data manifold 的概率分布 (measure);
2. metric generator 可以看成是在生成 real data manifold 上的一个度量 (metric)。
Table 1. GANs 和 MvM 的主要区别
在交替学习的过程中,metric generator 可以为distribution generator的manifold-matching objective 提供更好的metric;反过来,distribution generator 则可以为metric generator的metric-learning objective 提供更好的hard negative samples。最终的收敛其实是一个双赢的结果:distribution generator 学到了如何生成fake data distribution,而metric generator 则学到了比欧式距离更能反映real data set内在的几何结构的metric。
Figure 1. 度量学习和流形匹配的交互学习示意图。
实验方面,我们在无监督图像生成和超分辨率两个任务上验证了方法的灵活性和有效性。
我们首先对假设合理性,尤其是metric learning 的作用,做出了验证。在无监督图像生成任务的训练过程中,我们在每个epoch随机选取1024个真实数据样本,并用学习到的度量计算了这些真实样本的度量矩阵的特征值。结果发现随着训练的进行,特征值变得越来越小而且均匀 (Figure 2)。这意味着在学习过程中度量学习对真实数据流形有“熨平”的效果,使得数据流形具有各向同性。
Figure 2. 1024个真实数据随机样点的度量矩阵的前十个特征根在度量学习过程中的变化趋势。
另外对于manifold matching,我们对比了匹配不同几何特征的学习效果。结果表明同时匹配centroid和p-diameter的效果要好于只匹配其中一个几何特征,并给出了在二维空间上的直观阐释 (Figure 3)。我们还发现该方法的训练对batch size有很好的鲁棒性 (Table 2)。
Table 2. 各种batch size 所需要的训练消耗时间的对比。
Figure 3. (a,b,c)不同目标函数下,真实数据(绿色)和生成数据(橙色)的 UMAP投影图。(a)只有centroid项;(b)只有diameter项;(c)centroid+diameter.
在验证MvM生成图像的性能方面,我们的实验结果也优于在相同的模型架构下采用不同的GAN损失函数的实验结果。由于图像生成的能力与所选用的网络架构密切相关,而在文章中作为评估生成模型损失函数所使用的标准的ResNet架构更适用于小图生成,为了测试我们所提出的损失函数在大图生成任务上的表现,后续我们尝试了在大图生成上表现优异的StyleGAN2架构。下图为在FFHQ数据集上随机生成的512x512图像 (trained at ~150K iterations):
Figure 4. StyleGAN2 架构下,用MvM生成的高清人脸图片。
最后在超分辨率任务中,我们加入了低分辨率图像 (LR) 与高分辨率图像 (HR) 的监督损失函数作为MvM框架在监督学习任务中的一个使用方法。我们进行了两种实验:
(A) 将MvM作为GAN adversarial loss的替代;
(B) 把MvM作为perceptual/naturalness loss的替代,并在相同的实验条件下进行了结果对比 (Table 3)。
可以看到MvM在两种metric (similarity-based和perceptual-based)下的表现提升。下图为在DIV2K数据集x4任务上的生成样例,可以看到通过MvM损失函数生成出来的图像相比GAN具有更真实的纹理特征。
Figure 5. DIV2K数据集SISR任务生成图片的纹理对比。
Table 3. 不同设置下的评价分数对比。具体设置见Table 4。
Table 4. GAN-based SISR方法的各种训练设置。
在以上两类任务中,为了公平对比,我们仅使用了相同并且简单的模型架构来对比不同的损失函数。在实际应用中,可以尝试用MvM框架搭配不同的模型架构、训练方法从而达到更好的训练效果。
ICCV和CVPR 2021论文和代码下载
后台回复:CVPR2021,即可下载CVPR 2021论文和代码开源的论文合集
后台回复:ICCV2021,即可下载ICCV 2021论文和代码开源的论文合集
后台回复:Transformer综述,即可下载最新的两篇Transformer综述PDF
重磅!Transformer交流群成立
扫码添加CVer助手,可申请加入CVer-Transformer微信交流群,方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch和TensorFlow等群。
一定要备注:研究方向+地点+学校/公司+昵称(如Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
▲长按加小助手微信,进交流群
▲点击上方卡片,关注CVer公众号
整理不易,请点赞和在看
更多推荐
ICCV 2021
发布评论