Pix2Pix (Implementing an Image Style-Transfer Task from Scratch)


This post walks through implementing an image style-transfer task from scratch with a CGAN. The dataset images used here look like the example below: the goal is to generate the colored image on the left from the image on the right, which contains only the character's line art (in plain terms, colorizing a black-and-white sketch).

You can download the images from Kaggle via the link provided: anime image dataset.
The downloaded dataset contains some duplicated folders; in the end you only need to keep three of them, as shown below.

A Brief Introduction to the CGAN Paper

The authors point out that in an ordinary GAN, the generator is expected to learn a mapping from a noise distribution (for example a Gaussian or a uniform distribution) to the distribution we want:

$$random\ z,\qquad G: z \rightarrow y$$

In CGAN, however, the mapping becomes

$$random\ z,\ input\_image\ x,\qquad G: \{x, z\} \rightarrow y$$

Here input_image $x$ refers to the image that contains only line art.

In practice, the GitHub project I referenced does not feed random noise z into the Generator. My understanding is that the Generator's purpose is already to map one distribution into the distribution we want, so adding a noise distribution does not seem to contribute useful information. The paper's authors also mention that adding noise increases the randomness of the results only to a limited extent; of course, readers can add a noise input in their own implementation and compare the results (a rough sketch of one way to do this follows).
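The snippet below is only a hypothetical illustration of feeding noise alongside the input image; it is not part of the implementation in this post. It shows one common option: concatenating a spatial noise map to the input as an extra channel, which would require the generator's first convolution to accept 4 input channels instead of 3.

import torch

# hypothetical sketch: concatenate a noise map with the line-art input
x = torch.randn(1, 3, 256, 256)          # line-art input image
z = torch.randn(1, 1, 256, 256)          # spatial noise map with the same H and W
x_with_noise = torch.cat([x, z], dim=1)  # [1, 4, 256, 256]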


The figure above is a schematic of the Generator. It consists of two parts: an encoder that downsamples the input image and a decoder that upsamples it back. The authors observed that the input (the line-art-only image) and the output (the colored image) are structurally aligned; for example, both share the same line contours. They therefore concatenate the feature maps from the encoder onto the corresponding positions in the decoder; the dashed lines in the figure represent these skip connections.

What is special about the discriminator in CGAN is that its input is the concatenation of $x$ and $y$, and its output is no longer a scalar but a four-dimensional tensor; the paper's authors refer to this output as a patch.



That is all the information about the model design. Note, however, that in the implementation we replace BatchNorm with InstanceNorm.

Model Design

#Generator
import torch
from torch import nn
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, relu=False, down=True, use_dropout=False):
        # in_channels: number of input channels
        # out_channels: number of output channels
        # relu: use ReLU (True) or LeakyReLU (False)
        # down: downsample with Conv2d (True) or upsample with ConvTranspose2d (False)
        # use_dropout: whether to apply dropout
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode='reflect')
            if down else
            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.Dropout(0.5) if use_dropout else nn.Identity(),
            nn.ReLU() if relu else nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.block(x)


class Generator(nn.Module):
    def __init__(self, img_channels=3, feature=64):
        super().__init__()
        # encoder: the first layer has no normalization
        self.down1 = nn.Sequential(
            nn.Conv2d(img_channels, feature, 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        self.down2 = CNNBlock(feature, feature * 2)
        self.down3 = CNNBlock(feature * 2, feature * 4)
        self.down4 = CNNBlock(feature * 4, feature * 8)
        self.down5 = CNNBlock(feature * 8, feature * 8)
        self.down6 = CNNBlock(feature * 8, feature * 8)
        self.down7 = CNNBlock(feature * 8, feature * 8)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(feature * 8, feature * 8, 4, 2, 1, padding_mode='reflect'),
            nn.ReLU(),
        )
        # decoder: only the first three layers use dropout
        self.up1 = CNNBlock(feature * 8, feature * 8, relu=True, down=False, use_dropout=True)
        self.up2 = CNNBlock(feature * 16, feature * 8, relu=True, down=False, use_dropout=True)
        self.up3 = CNNBlock(feature * 16, feature * 8, relu=True, down=False, use_dropout=True)
        self.up4 = CNNBlock(feature * 16, feature * 8, relu=True, down=False)
        self.up5 = CNNBlock(feature * 16, feature * 4, relu=True, down=False)
        self.up6 = CNNBlock(feature * 8, feature * 2, relu=True, down=False)
        self.up7 = CNNBlock(feature * 4, feature, relu=True, down=False)
        # the last layer uses Tanh as the activation
        self.final_conv = nn.Sequential(
            nn.ConvTranspose2d(feature * 2, img_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        # x: [N, img_channels, 256, 256]
        down1 = self.down1(x)                                # [N, 64, 128, 128]
        down2 = self.down2(down1)                            # [N, 128, 64, 64]
        down3 = self.down3(down2)                            # [N, 256, 32, 32]
        down4 = self.down4(down3)                            # [N, 512, 16, 16]
        down5 = self.down5(down4)                            # [N, 512, 8, 8]
        down6 = self.down6(down5)                            # [N, 512, 4, 4]
        down7 = self.down7(down6)                            # [N, 512, 2, 2]
        bottleneck = self.bottleneck(down7)                  # [N, 512, 1, 1]
        up1 = self.up1(bottleneck)                           # [N, 512, 2, 2]
        up2 = self.up2(torch.cat([up1, down7], 1))           # [N, 512, 4, 4]
        up3 = self.up3(torch.cat([up2, down6], 1))           # [N, 512, 8, 8]
        up4 = self.up4(torch.cat([up3, down5], 1))           # [N, 512, 16, 16]
        up5 = self.up5(torch.cat([up4, down4], 1))           # [N, 256, 32, 32]
        up6 = self.up6(torch.cat([up5, down3], 1))           # [N, 128, 64, 64]
        up7 = self.up7(torch.cat([up6, down2], 1))           # [N, 64, 128, 128]
        return self.final_conv(torch.cat([up7, down1], 1))   # [N, 3, 256, 256]
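A quick shape sanity check (my own addition, not from the referenced project) to confirm that the generator maps a 256x256 input back to a 256x256 output:

# sanity check: a [N, 3, 256, 256] input should produce a [N, 3, 256, 256] output
if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    gen = Generator(img_channels=3, feature=64)
    print(gen(x).shape)  # expected: torch.Size([2, 3, 256, 256])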

The Discriminator is just a simple stack of downsampling layers.

import torch
from torch import nn
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias=False, padding_mode='reflect'),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)


class Discriminator(nn.Module):
    def __init__(self, img_channels, features=[64, 128, 256, 512]):
        super().__init__()
        # the input is x and y concatenated along the channel dimension,
        # hence img_channels * 2 input channels; the first layer has no normalization
        self.init_conv = nn.Sequential(
            nn.Conv2d(img_channels * 2, features[0], 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(CNNBlock(in_channels, feature, stride=1 if feature == features[-1] else 2))
            in_channels = feature
        # final convolution maps to a single channel: the output is a patch of scores,
        # not a single scalar
        layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], 1)
        x = self.init_conv(x)
        return self.model(x)
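A small check (again my own addition) of the patch output described earlier: with 256x256 inputs and the strides used here, the discriminator returns a 4-D tensor of per-patch scores rather than a single scalar.

# sanity check: the discriminator concatenates x and y and outputs patch scores
if __name__ == "__main__":
    x = torch.randn(2, 3, 256, 256)
    y = torch.randn(2, 3, 256, 256)
    disc = Discriminator(img_channels=3)
    print(disc(x, y).shape)  # with these settings: torch.Size([2, 1, 30, 30])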

Preparing the Dataset

Once the dataset has been downloaded, we need to wrap it in a Dataset class so that it can be loaded by a DataLoader.

# config.py
# the data augmentation operations are defined here
import albumentations as A
from albumentations.pytorch import ToTensorV2

# transforms applied to both images
both_transforms = A.Compose(
    [
        # resize, and flip horizontally with 50% probability
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)
transform_only_input = A.Compose([
    A.ColorJitter(p=0.1),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.),
    ToTensorV2(),
])
transform_only_mask = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.),
    ToTensorV2(),
])
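A brief illustration (my own addition, using dummy arrays) of what additional_targets={"image0": "image"} buys us: albumentations applies the same random resize and flip to both images, so the input/target pair stays aligned.

import numpy as np

# the same random flip is applied to both arrays, keeping the pair aligned
dummy_input = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
dummy_target = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
out = both_transforms(image=dummy_input, image0=dummy_target)
print(out["image"].shape, out["image0"].shape)  # (256, 256, 3) (256, 256, 3)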
import torch
from torch.utils.data import Dataset
import config
import numpy as np
import os
from PIL import Image
class AnimeDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.paths = os.listdir(root_dir)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img_path = os.path.join(self.root_dir, self.paths[index])
        img = np.array(Image.open(img_path))
        # each image is 1024 pixels wide: the right 512 is the input (line art),
        # the left 512 is the target (colored image)
        input_image = img[:, 512:, :]
        target_image = img[:, :512, :]
        augmentations = config.both_transforms(image=input_image, image0=target_image)
        input_image, target_image = augmentations['image'], augmentations['image0']
        input_image = config.transform_only_input(image=input_image)['image']
        target_image = config.transform_only_mask(image=target_image)['image']
        return input_image, target_image
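A short usage sketch (my addition; the directory path is only an example) showing how the dataset plugs into a DataLoader:

if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = AnimeDataset(root_dir='../data/anime/train')  # example path
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    x, y = next(iter(loader))
    print(x.shape, y.shape)  # torch.Size([4, 3, 256, 256]) torch.Size([4, 3, 256, 256])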

Some tools that will be used during training

config is just a configuration file recording the hyperparameters and the data-augmentation code.

#config.py
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda" if torch.cuda.is_available() else 'cpu'
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
NUM_EPOCHS = 500
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"
both_transforms = A.Compose(
    [A.Resize(width=IMAGE_SIZE, height=IMAGE_SIZE), A.HorizontalFlip(p=0.5)],
    additional_targets={"image0": "image"},
)
transform_only_input = A.Compose([
    A.ColorJitter(p=0.1),
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])
transform_only_mask = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2(),
])

Some helpers that may be used during training

#utils.py
import torch
import config
from torchvision.utils import save_image


def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE), y.to(config.DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # undo the normalization back to [0, 1]
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()


def save_checkpoint(model, optimizer, filename):
    checkpoint = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, filename)


def load_checkpoint(model, optimizer, filename, lr):
    state_dict = torch.load(filename, map_location=config.DEVICE)
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optimizer'])
    # reset the learning rate after loading the optimizer state
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

Training Procedure

import torch
from torch import optim, nn
from generator import Generator
from discriminator import Discriminator
from dataset import AnimeDataset
from torch.utils.data import DataLoader
from utils import *
import config
from tqdm import tqdm

disc = Discriminator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
gen = Generator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
train_set = AnimeDataset(root_dir='../data/anime/train')
train_loader = DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS)
val_set = AnimeDataset(root_dir='../data/anime/val')
val_loader = DataLoader(val_set, batch_size=1, shuffle=False)

bce = nn.BCEWithLogitsLoss()  # adversarial loss, still the standard (JS-divergence-based) GAN loss
L1_loss = nn.L1Loss()         # L1 term between the generated and target images

if config.LOAD_MODEL:
    load_checkpoint(gen, opt_gen, config.CHECKPOINT_GEN, config.LEARNING_RATE)
    load_checkpoint(disc, opt_disc, config.CHECKPOINT_DISC, config.LEARNING_RATE)

for epoch in range(1, config.NUM_EPOCHS + 1):
    loop = tqdm(train_loader, leave=True)  # show a training progress bar
    for x, y in loop:
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)

        # train the discriminator
        y_fake = gen(x)
        disc_real = disc(x, y).reshape(-1)
        disc_fake = disc(x, y_fake.detach()).reshape(-1)
        lossD_real = bce(disc_real, torch.ones_like(disc_real))
        lossD_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        lossD = lossD_real + lossD_fake
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # train the generator
        D_fake = disc(x, y_fake).reshape(-1)
        G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
        L1 = L1_loss(y_fake, y) * config.L1_LAMBDA
        G_loss = G_fake_loss + L1
        opt_gen.zero_grad()
        G_loss.backward()
        opt_gen.step()

    if config.SAVE_MODEL and epoch % 5 == 0:
        save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
    save_some_examples(gen, val_loader, epoch, folder='evaluation')
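After training, the following is a minimal inference sketch (my own addition; the image path and output file name are only examples) for colorizing a single line-art image with the saved generator:

import numpy as np
from PIL import Image
from torchvision.utils import save_image

gen = Generator(img_channels=config.CHANNELS_IMG).to(config.DEVICE)
opt = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE)
load_checkpoint(gen, opt, config.CHECKPOINT_GEN, config.LEARNING_RATE)
gen.eval()

img = np.array(Image.open('../data/anime/val/example.png'))                # example paired image, 1024 px wide
line_art = np.array(Image.fromarray(img[:, 512:, :]).resize((256, 256)))   # right half = line-art input
x = config.transform_only_input(image=line_art)['image'].unsqueeze(0).to(config.DEVICE)

with torch.no_grad():
    y_fake = gen(x)
save_image(y_fake * 0.5 + 0.5, 'colorized.png')  # undo the [-1, 1] normalization before saving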
