用PyTorch代码生成对抗网络(GAN)

编程入门 行业动态 更新时间:2024-10-06 16:29:27

用PyTorch<a href=https://www.elefans.com/category/jswz/34/1771412.html style=代码生成对抗网络(GAN)"/>

用PyTorch代码生成对抗网络(GAN)

一、说明

        2014年,蒙特利尔大学的伊恩·古德费罗(Ian Goodfellow)和他的同事发表了一篇令人惊叹的论文,向世界介绍了GANs生成对抗网络。通过计算图和博弈论的创新组合,他们表明,如果有足够的建模能力,相互竞争的两个模型将能够通过普通的旧反向传播进行共同训练。

二、原理和实现

        这些模型扮演着两个不同的(字面意思是对抗的)角色。给定一些真实的数据集R,G是生成器,试图创建看起来像真实数据的假数据,而D鉴别器,从真实集或G获取数据并标记差异。 Goodfellow的比喻(这是一个很好的比喻)是,G就像一群伪造者,试图将真实的画作与他们的作品相匹配,而D是试图区分的侦探团队。(除了在这种情况下,伪造者G永远看不到原始数据——只能看到D的判断。他们就像盲目的伪造者。

        在理想情况下,D和G都会随着时间的推移而变得更好,直到G基本上成为真品的“伪造大师”,而D不知所措,“无法区分两个分布”。

        在实践中,Goodfellow所展示的是,G将能够在原始数据集上执行一种形式的无监督学习,找到某种以(可能)低维方式表示该数据的方法。正如Yann LeCun的名言,无监督学习是真正AI的“蛋糕”。

        这种强大的技术似乎必须需要一公吨的代码才能开始,对吧?不。使用 PyTorch,我们实际上可以在不到 50 行代码中创建非常简单的 GAN。实际上只有 5 个组件需要考虑:

  • R:原始的、真实的数据集
  • I:作为熵源进入发生器的随机噪声
  • G:尝试复制/模仿原始数据集的生成器
  • D:试图区分 G 的输出和 R 的鉴别器
  • 实际的“训练”循环,我们教 G 欺骗 D 和 D 提 G

1.) R:在我们的例子中,我们将从最简单的 R 开始——钟形曲线。此函数采用平均值和标准差,并返回一个函数,该函数提供具有这些参数的高斯样本数据的正确形状。在我们的示例代码中,我们将使用平均值 4.0 和标准差 1.25。

2.) I:生成器的输入也是随机的,但为了使我们的工作更难一点,让我们使用均匀分布而不是正态分布。这意味着我们的模型 G 不能简单地移动/缩放输入来复制 R,而必须以非线性方式重塑数据。

3.) G:生成器是一个标准的前馈图——两个隐藏层,三个线性映射。我们正在使用双曲正切激活函数,因为我们是老派的。G 将从 I 获取均匀分布的数据样本,并以某种方式模仿来自 R 的正态分布样本——而从未见过 R

4.) D:鉴别器代码与G的生成器代码非常相似;具有两个隐藏层和三个线性映射的前馈图。这里的激活函数是一个sigmoid——没什么花哨的,人们。它将从 R 或 G 获取样本,并将输出 0 到 1 之间的单个标量,解释为“假”与“真”。换句话说,这和神经网络所能得到的一样温和

5.)最后,训练循环在两种模式之间交替:首先在真实数据上训练D,在假数据上训练D,并带有准确的标签(可以想象这是警察学院);然后训练 G 用不准确的标签愚弄 D(这更像是《海洋十一人》中的那些准备蒙太奇)。这是一场善与恶之间的斗争,人。

        即使你以前没有见过PyTorch,你也可能知道发生了什么。在第一部分(绿色)中,我们通过 D 推送两种类型的数据,并对 D 的猜测与实际标签应用可微标准。这种推动是“前进”的一步;然后,我们显式调用 'back()' 来计算梯度,然后在 d_optimizer step() 调用中更新 D 的参数。这里使用了 G,但没有训练。

        然后在最后一个(红色)部分中,我们对 G 做了同样的事情 — 请注意,我们也通过 D 运行 G 的输出(我们本质上是给伪造者一个侦探来练习),但我们在此步骤中没有优化或更改 D。我们不希望侦探D学习错误的标签。因此,我们只调用 g_optimizer.step()。

和。。。仅此而已。还有其他一些样板代码,但GAN特定的东西只是这5个组件,没有别的。

        经过D和G之间几千轮的禁舞,我们得到了什么?鉴别器D变得非常好(而G慢慢向上移动),但是一旦达到一定的功率水平,G就会有一个有价值的对手并开始改进。 真的进步了。

        超过 5,000 轮训练,每轮训练 D 20 次,然后训练 G 20 次,G 输出的平均值超过 4.0,但随后回到一个相当稳定、正确的范围内(左)。同样,标准差最初在错误的方向上下降,但随后上升到所需的 1.25 范围(右),与 R 匹配。

        好的,所以基本统计数据最终与 R 匹配。更高的时刻呢?分布的形状看起来是否正确?毕竟,你当然可以有一个均值为 4.0 和标准差为 1.25 的均匀分布,但这并不真正匹配 R。让我们看一下 G 发出的最终分布:

        不错。右尾比左尾肥一点,但可以说,歪斜和峰度让人想起原始的高斯。

        G几乎完美地恢复了原始分布R——而D则蜷缩在角落里,喃喃自语,无法分辨事实和虚构。这正是我们想要的行为(参见 Goodfellow 中的图 1)。从不到 50 行代码。

        现在,警告一句:GANs可能会很挑剔。而且脆弱。当他们进入奇怪的状态时,他们往往不会在没有一点哄骗的情况下出来。运行我的示例代码十次(每次超过 5,000 轮)显示了以下十个分布:

        十次运行中有 4 次产生了相当好的最终分布——类似于高斯,均值为 5,在正确的球场上有标准差。但是其中两次运行没有 — 在一种情况下(运行 #6),有一个凹分布,平均值约为 0.10,而在最后一次运行 (#11) 中,有一个狭窄的峰值在 -<>!当你开始在几乎任何环境中应用 GAN 时,你会看到这种现象——GAN 并不像普通的监督学习工作流程那样稳定。但是当它们工作时,它们看起来几乎是神奇的

        Goodfellow将继续发表许多其他关于GAN的论文,包括2016年的一篇描述一些实际改进的宝石,包括这里采用的小批量判别方法。这是他在2 NIPS上展示的2016小时教程。对于TensorFlow用户,这是Aylien关于GANs的平行帖子。

      

三、代码实现

  好了,说够了。去看看代码。

#!/usr/bin/env python# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at /@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variablematplotlib_is_available = True
try:from matplotlib import pyplot as plt
except ImportError:print("Will skip plotting; matplotlib is not available.")matplotlib_is_available = False# Data params
data_mean = 4
data_stddev = 1.25# ### Uncomment only one of these to define what data is actually sent to the Discriminator
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
#(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
#(name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)print("Using data [%s]" % (name))# ##### DATA: Target data and generator input datadef get_distribution_sampler(mu, sigma):return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussiandef get_generator_input_sampler():return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian# ##### MODELS: Generator model and discriminator modelclass Generator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Generator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.map1(x)x = self.f(x)x = self.map2(x)x = self.f(x)x = self.map3(x)return xclass Discriminator(nn.Module):def __init__(self, input_size, hidden_size, output_size, f):super(Discriminator, self).__init__()self.map1 = nn.Linear(input_size, hidden_size)self.map2 = nn.Linear(hidden_size, hidden_size)self.map3 = nn.Linear(hidden_size, output_size)self.f = fdef forward(self, x):x = self.f(self.map1(x))x = self.f(self.map2(x))return self.f(self.map3(x))def extract(v):return v.data.storage().tolist()def stats(d):return [np.mean(d), np.std(d)]def get_moments(d):# Return the first 4 moments of the data providedmean = torch.mean(d)diffs = d - meanvar = torch.mean(torch.pow(diffs, 2.0))std = torch.pow(var, 0.5)zscores = diffs / stdskews = torch.mean(torch.pow(zscores, 3.0))kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussianfinal = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))return finaldef decorate_with_diffs(data, exponent, remove_raw_data=False):mean = torch.mean(data.data, 1, keepdim=True)mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])diffs = torch.pow(data - Variable(mean_broadcast), exponent)if remove_raw_data:return torch.cat([diffs], 1)else:return torch.cat([data, diffs], 1)def train():# Model parametersg_input_size = 1      # Random noise dimension coming into generator, per output vectorg_hidden_size = 5     # Generator complexityg_output_size = 1     # Size of generated output vectord_input_size = 500    # Minibatch size - cardinality of distributionsd_hidden_size = 10    # Discriminator complexityd_output_size = 1     # Single dimension for 'real' vs. 'fake' classificationminibatch_size = d_input_sized_learning_rate = 1e-3g_learning_rate = 1e-3sgd_momentum = 0.9num_epochs = 5000print_interval = 100d_steps = 20g_steps = 20dfe, dre, ge = 0, 0, 0d_real_data, d_fake_data, g_fake_data = None, None, Nonediscriminator_activation_function = torch.sigmoidgenerator_activation_function = torch.tanhd_sampler = get_distribution_sampler(data_mean, data_stddev)gi_sampler = get_generator_input_sampler()G = Generator(input_size=g_input_size,hidden_size=g_hidden_size,output_size=g_output_size,f=generator_activation_function)D = Discriminator(input_size=d_input_func(d_input_size),hidden_size=d_hidden_size,output_size=d_output_size,f=discriminator_activation_function)criterion = nn.BCELoss()  # Binary cross entropy: .html#bcelossd_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)for epoch in range(num_epochs):for d_index in range(d_steps):# 1. Train D on real+fakeD.zero_grad()#  1A: Train D on reald_real_data = Variable(d_sampler(d_input_size))d_real_decision = D(preprocess(d_real_data))d_real_error = criterion(d_real_decision, Variable(torch.ones([1])))  # ones = trued_real_error.backward() # compute/store gradients, but don't change params#  1B: Train D on faked_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labelsd_fake_decision = D(preprocess(d_fake_data.t()))d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1])))  # zeros = faked_fake_error.backward()d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]for g_index in range(g_steps):# 2. Train G on D's response (but DO NOT train D on these labels)G.zero_grad()gen_input = Variable(gi_sampler(minibatch_size, g_input_size))g_fake_data = G(gen_input)dg_fake_decision = D(preprocess(g_fake_data.t()))g_error = criterion(dg_fake_decision, Variable(torch.ones([1])))  # Train G to pretend it's genuineg_error.backward()g_optimizer.step()  # Only optimizes G's parametersge = extract(g_error)[0]if epoch % print_interval == 0:print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %(epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))if matplotlib_is_available:print("Plotting the generated distribution...")values = extract(g_fake_data)print(" Values: %s" % (str(values)))plt.hist(values, bins=50)plt.xlabel('Value')plt.ylabel('Count')plt.title('Histogram of Generated Distribution')plt.grid(True)plt.show()train()

更多推荐

用PyTorch代码生成对抗网络(GAN)

本文发布于:2024-03-23 17:43:40,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1740967.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:代码   网络   PyTorch   GAN

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!