深度学习系列33:有标签的CGAN:Pix2Pix/Pix2PixHD/cycleGAN

编程入门 行业动态 更新时间:2024-10-24 10:26:15

<a href=https://www.elefans.com/category/jswz/34/1769690.html style=深度学习系列33:有标签的CGAN:Pix2Pix/Pix2PixHD/cycleGAN"/>

深度学习系列33:有标签的CGAN:Pix2Pix/Pix2PixHD/cycleGAN

1. 从GAN到CGAN

GAN的训练数据是没有标签的,如果我们要做有标签的训练,则需要用到CGAN。
对于图像来说,我们既要让输出的图片真实,也要让输出的图片符合标签c。Discriminator输入便被改成了同时输入c和x,输出要做两件事情,一个是判断x是否是真实图片,另一个是x和c是否是匹配的。
在下面两个情况中,左边虽然输出图片清晰,但不符合c;右边输出图片不真实。因此两种情况中D的输出都会是0。

我们来看下简单的示例代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image# 独热编码
# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform=one_hot,download=False)
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)# 定义生成器
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear1 = nn.Linear(10, 128 * 7 * 7)self.bn1 = nn.BatchNorm1d(128 * 7 * 7)self.linear2 = nn.Linear(100, 128 * 7 * 7)self.bn2 = nn.BatchNorm1d(128 * 7 * 7)self.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3, 3),padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4, 4),stride=2,padding=1)self.bn4 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1)def forward(self, x1, x2):x1 = F.relu(self.linear1(x1))x1 = self.bn1(x1)x1 = x1.view(-1, 128, 7, 7)x2 = F.relu(self.linear2(x2))x2 = self.bn2(x2)x2 = x2.view(-1, 128, 7, 7)x = torch.cat([x1, x2], axis=1)x = F.relu(self.deconv1(x))x = self.bn3(x)x = F.relu(self.deconv2(x))x = self.bn4(x)x = torch.tanh(self.deconv3(x))return x# 定义判别器
# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.linear = nn.Linear(10, 1*28*28)self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x1, x2):x1 =F.leaky_relu(self.linear(x1))x1 = x1.view(-1, 1, 28, 28)x = torch.cat([x1, x2], axis=1)x = F.dropout2d(F.leaky_relu(self.conv1(x)))x = F.dropout2d(F.leaky_relu(self.conv2(x)))x = self.bn(x)x = x.view(-1, 128*6*6)x = torch.sigmoid(self.fc(x))return x# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)# 损失计算函数
loss_function = torch.nn.BCELoss()# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())fig = plt.figure(figsize=(4, 4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i + 1)plt.imshow((predictions[i] + 1) / 2, cmap='gray')plt.axis("off")plt.show()
noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)
print(label_seed)
# print(label_seed_onehot)# 开始训练
D_loss = []
G_loss = []
# 训练循环
for epoch in range(150):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader.dataset)# 对全部的数据集做一次迭代for step, (img, label) in enumerate(dataloader):img = img.to(device)label = label.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)d_optim.zero_grad()real_output = dis(label, img)d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))d_real_loss.backward() #求解梯度# 得到判别器在生成图像上的损失gen_img = gen(label,random_noise)fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()  # 优化# 得到生成器的损失g_optim.zero_grad()fake_output = dis(label, gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_loss.item()g_epoch_loss += g_loss.item()with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)if epoch % 10 == 0:print('Epoch:', epoch)generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)

2. Pix2pix:像素级别转换

这里是尝试地址:/
使用Pix2Pix神经网络模型实现论文中预定义的任务:黑白简笔画到彩图、平面房屋到立体房屋和航拍图到地图等功能:

Pix2pixgan本质上是一个cgan,图片x作为此cGAN的条件, 需要输入到G和D中。 G的输入是x(x 是需要转换的图片),输出是生成的图片G(x)。 D则需要分辨出{x,G(x)}和{x, y}。

这里的生成器模型我们采用U-Net:

在pix2pix中,作者就是把L1 loss 和GAN loss相结合使用,因为作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分。判别器使用patchGAN。



我们看一些代码说明:

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.down1 = Downsample(3, 64)self.down2 = Downsample(64, 128)self.down3 = Downsample(128, 256)self.down4 = Downsample(256, 512)self.down5 = Downsample(512, 512)self.down6 = Downsample(512, 512)self.up1 = Upsample(512, 512)self.up2 = Upsample(1024, 512)self.up3 = Upsample(1024, 256)self.up4 = Upsample(512, 128)self.up5 = Upsample(256, 64)self.last = nn.ConvTranspose2d(128, 3,kernel_size=3,stride=2,padding=1,output_padding=1)def forward(self, x):x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])x6 = self.up3(x6, is_drop=True)x6 = torch.cat([x3, x6], dim=1)x6 = self.up4(x6)x6 = torch.cat([x2, x6], dim=1)x6 = self.up5(x6)x6 = torch.cat([x1, x6], dim=1)x6 = torch.tanh(self.last(x6))return x6# 判别器
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.down1 = Downsample(6, 64)self.down2 = Downsample(64, 128)self.down3 = Downsample(128, 256)self.conv = nn.Conv2d(256, 512, 3, 1, 1)self.bn = nn.BatchNorm2d(512)self.last = nn.Conv2d(512, 1, 3, 1)def forward(self, anno, img):x = torch.cat([anno, img], dim=1)  # batch*6*H*Wx = self.down1(x, is_bn=False)x = self.down2(x)x = F.dropout2d(self.down3(x))x = F.dropout2d(F.leaky_relu(self.conv(x)))x = F.dropout2d(self.bn(x))x = torch.sigmoid(self.last(x))return x

3. Pix2PixHD

在pix2pix的基础上,增加了一个“从糙到精生成器(coarse-to-fine generator)”、一个多尺度鉴别器架构和一个健壮的对抗学习目标函数。
1)生成器部分提高分辨率:将生成器U-net拆分成两个子网络G1和G2进行训练:前者输入和输出的分辨率保持一致(如 1024 x 512),后者输出尺寸(2048x1024)是输入尺寸(1024x512)的4倍(长宽各两倍)。如果想要得到更高分辨率的图像,只需要增加更多的局部增强网络即可(如 G={G1,G2,G3})

2)判别器部分将深度改为宽度:使用三个相同结构的判别器,分别处理不同尺寸的输入。
3)损失函数更稳健:除了PatchGAN的损失,还加上了样本与GT使用判别器网络和VGG16网络提取特征后进行的Element-wise loss
4)输入加入高频特征向量,例如图像的边缘信息,与输入的语义标签连接到一起作为输入。
5)额外学习一个Feature encoder网络,可以将原图转化为features,用来控制图像的颜色、纹理信息。

4. CycleGAN:风格转换

pix2pixGAN有一个明显的缺点就是,在进行训练的时候必须提供成对的数据集。比如当我们想生成梵高风格的画时,梵高本人画的作品肯定是相对较少的,这个时候就可以考虑使用cycleGAN。cycleGAN适用于非配对的图像到图像转换:

其原理可以概括为将一类图片转成成另一类图片,比如,现有两个样本空间X、Y,我们希望把X空间中的样本转换成Y空间中的样本。这种转换只是风格上的转换,实际X Y 的内容是不一样的。实际的目标就是学习从X到Y的映射,假设该映射为F,它就对应着GAN中的生成器,F就可以将X中的图片A转换为Y中的图片F(A)。
为了实现这个过程,我们需要两个生成器 G_AB 和 G_BA:

首先是单向loss的组成:
判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的 B 图片,这个流程和GAN是一致的。

生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。除了GAN loss,还包含如下loss:
① cycle-loss:也就是循环一致损失。因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GA-B生成一张假图像,那么要能够使用另外一个生成器 GB-A来努力恢复成原始图像。此过程必须满足循环一致性

② 等价loss:我们要求 G A B ( b ) = b G_{AB}(b)=b GAB​(b)=b,以及 G B A ( a ) = a G_{BA}(a)=a GBA​(a)=a。

下面来看下示例代码:
获取苹果橙子数据:

# 加载训练数据
apples_path = glob.glob('data/trainA/*.jpg')
oranges_path = glob.glob('data/trainB/*.jpg')transform = transforms.Compose([transforms.ToTensor(),  # 0-1归一化transforms.Normalize(0.5, 0.5),  # -1,1])class AppleOrangeDataset(data.Dataset):def __init__(self, img_path):self.img_path = img_pathdef __getitem__(self, index):img_path = self.img_path[index]pil_img = Image.open(img_path)pil_img = transform(pil_img)return pil_imgdef __len__(self):return len(self.img_path)apple_dataset = AppleOrangeDataset(apples_path)
oranges_dataset = AppleOrangeDataset(oranges_path)

基于Unet结构定义上 / 下采样模块,接着定义生成器:

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.down1 = Downsample(3, 64)self.down2 = Downsample(64, 128)self.down3 = Downsample(128, 256)self.down4 = Downsample(256, 512)self.down5 = Downsample(512, 512)self.down6 = Downsample(512, 512)self.up1 = Upsample(512, 512)self.up2 = Upsample(1024, 512)self.up3 = Upsample(1024, 256)self.up4 = Upsample(512, 128)self.up5 = Upsample(256, 64)self.last = nn.ConvTranspose2d(128, 3,kernel_size=3,stride=2,padding=1,output_padding=1)def forward(self, x):x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])x6 = self.up3(x6, is_drop=True)x6 = torch.cat([x3, x6], dim=1)x6 = self.up4(x6)x6 = torch.cat([x2, x6], dim=1)x6 = self.up5(x6)x6 = torch.cat([x1, x6], dim=1)x6 = torch.tanh(self.last(x6))return x6

接下来是鉴别器:

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.down1 = Downsample(3, 64)             # 128self.down2 = Downsample(64, 128)           # 64self.last = nn.Conv2d(128, 1, 3)def forward(self, img):x = self.down1(img)x = self.down2(x)x = torch.sigmoid(self.last(x))return x

我们需要定义两个生成器和两个鉴别器:


gen_AB = Generator().to(device)
gen_BA = Generator().to(device)
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)# 同时对两个生成器进行优化
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()),lr=2e-4, betas=(0.5, 0.999))
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

训练过程如下:

D_loss = []  # 记录训练过程中判别器loss变化
G_loss = []  # 记录训练过程中生成器loss变化# 开始训练
for epoch in range(50):D_epoch_loss = 0G_epoch_loss = 0for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)):# GAN 训练gen_optimizer.zero_grad()# identity losssame_B = gen_AB(real_B)identity_B_loss = l1loss_fn(same_B, real_B)same_A = gen_BA(real_A)identity_A_loss = l1loss_fn(same_A, real_A)# GAN lossfake_B = gen_AB(real_A)D_pred_fake_B = dis_B(fake_B)gan_loss_AB = bceloss_fn(D_pred_fake_B,torch.ones_like(D_pred_fake_B, device=device))fake_A = gen_BA(real_B)D_pred_fake_A = dis_A(fake_A)gan_loss_BA = bceloss_fn(D_pred_fake_A,torch.ones_like(D_pred_fake_A, device=device))# cycle consistanse lossrecovered_A = gen_BA(fake_B)cycle_loss_ABA = l1loss_fn(recovered_A, real_A)recovered_B = gen_AB(fake_A)cycle_loss_BAB = l1loss_fn(recovered_B, real_B)# total_lossg_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA+ cycle_loss_ABA + cycle_loss_BAB)g_loss.backward()gen_optimizer.step()# dis_A 训练dis_A_optimizer.zero_grad()dis_A_real_output = dis_A(real_A)  # 判别器输入真实图片dis_A_real_loss = bceloss_fn(dis_A_real_output,torch.ones_like(dis_A_real_output, device=device))dis_A_fake_output = dis_A(fake_A.detach())  # 判别器输入生成图片dis_A_fake_loss = bceloss_fn(dis_A_fake_output,torch.zeros_like(dis_A_fake_output, device=device))dis_A_loss = (dis_A_real_loss + dis_A_fake_loss) * 0.5dis_A_loss.backward()dis_A_optimizer.step()# dis_B 训练dis_B_optimizer.zero_grad()dis_B_real_output = dis_B(real_B)  # 判别器输入真实图片dis_B_real_loss = bceloss_fn(dis_B_real_output,torch.ones_like(dis_B_real_output, device=device))dis_B_fake_output = dis_B(fake_B.detach())  # 判别器输入生成图片dis_B_fake_loss = bceloss_fn(dis_B_fake_output,torch.zeros_like(dis_B_fake_output, device=device))dis_B_loss = (dis_B_real_loss + dis_B_fake_loss) * 0.5dis_B_loss.backward()dis_B_optimizer.step()

更多推荐

深度学习系列33:有标签的CGAN:Pix2Pix/Pix2PixHD/cycleGAN

本文发布于:2024-02-25 16:23:00,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1699589.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:深度   标签   系列   cycleGAN   Pix2Pix

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!