网络模型(Seq2Seq)

编程入门 行业动态 更新时间:2024-10-28 00:17:02

网络模型(Seq2Seq)

网络模型(Seq2Seq)

概念

用于处理序列问题:翻译(N vs N)、信息提取(N vs 1)、生成(1 vs N)。

RNN 要求输入序列和输出序列等长,Seq2Seq 可以解决输入序列与输出序列不等长的问题。

实验(验证码识别)

数据集:生成 4 位数字的验证码图片(测试集和训练集各 1000 张),图片名称为 index.code.jpg,截取 code 作为标签。

网络结构:

  • 编码:全连接 + 标准化(BN)+ 激活(ReLU)+ LSTM。
  • 解码:LSTM + 全连接 + softmax(多分类)。

优化器:Adam。

损失函数:均方差(MSELoss)。

输出:4 个 one-hot 类型,结果为最大的索引值。

生成验证码

import os
import random

from PIL import Image, ImageDraw, ImageFilter, ImageFont


# Random digit
def rand_char():
    """Return one random decimal digit as a single-character string."""
    # Code points 48..57 are the ASCII characters "0".."9".
    code_point = random.randint(48, 57)
    return chr(code_point)
def rand_bg():
    """Return a random (R, G, B) tuple for one background pixel.

    Each channel is drawn independently from the inclusive range 50..150.
    """
    red = random.randint(50, 150)
    green = random.randint(50, 150)
    blue = random.randint(50, 150)
    return (red, green, blue)
def rand_color():
    """Return a random (R, G, B) tuple for a digit.

    Each channel is drawn independently from the inclusive range 100..255.
    """
    return (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))


# Canvas size: four digits, one per 60-pixel-wide cell.
width = 240
height = 60
font = ImageFont.truetype("arial.ttf", size=36)

# Image.save does not create missing directories, so make them up front.
os.makedirs("data", exist_ok=True)
os.makedirs("test", exist_ok=True)

for i in range(1000):
    img = Image.new("RGB", (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    # Paint the background with per-pixel random colour noise.
    for x in range(width):
        for y in range(height):
            draw.point((x, y), rand_bg())

    # Draw the four digits, one per 60-pixel cell.
    chrs = []
    for n in range(4):
        each = rand_char()
        chrs.append(each)
        draw.text((n * 60 + 10, 10), each, fill=rand_color(), font=font)

    # Blur the finished picture. The result must be bound back to ``img``;
    # the previous code assigned to an undefined name and raised NameError.
    img = img.filter(ImageFilter.BLUR)

    # File name layout is "<index>.<code>.jpg"; the code part is the label.
    # NOTE(review): the same picture is written to both folders, so the
    # "test" set is a copy of the training set, and the training script in
    # this article reads from data/train_dataset and data/test_dataset --
    # confirm the intended layout.
    img.save("data/{}.{}{}{}{}.jpg".format(i, chrs[0], chrs[1], chrs[2], chrs[3]))
    img.save("test/{}.{}{}{}{}.jpg".format(i, chrs[0], chrs[1], chrs[2], chrs[3]))

数据集

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import os
from PIL import Image


class MyDataset(Dataset):
    """Captcha images stored as ``<index>.<code>.jpg`` inside one directory.

    Each item is ``(image, label)``: the image is a normalised tensor and
    the label is a ``4 x 10`` one-hot tensor built from the four digits in
    the file name.
    """

    def __init__(self, path):
        """Index every file found directly inside *path*."""
        # ToTensor scales pixels to [0, 1]; Normalize then maps each of the
        # three channels to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.imgs = sorted(os.listdir(path))
        self.path = path

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load the image at *index* and derive its label from the name."""
        name = self.imgs[index]
        # The context manager closes the file handle promptly; convert("RGB")
        # guarantees the three channels that Normalize expects.
        with Image.open(os.path.join(self.path, name)) as img:
            img = self.transform(img.convert("RGB"))
        # "<index>.<code>.jpg" -> "<code>"
        label = self.one_hot(name.split(".")[1])
        return img, label

    def one_hot(self, x):
        """Convert a 4-digit string into a ``4 x 10`` one-hot float tensor."""
        result = torch.zeros(4, 10)
        for i in range(4):
            result[i, int(x[i])] = 1
        return result

网络

import torch
from torch import nn
from torch.nn import functional as f# 编码器
class Encoder(nn.Module):
    """Encode an image into one feature vector by scanning its columns.

    Every pixel column (all channels x all rows) is one time step. Each
    column goes through Linear + BatchNorm + ReLU, the column sequence is
    fed to an LSTM, and the output of the last step is returned.

    Args:
        in_features: values per column, i.e. channels * height
            (3 * 60 = 180 for the captcha images used here).
        hidden_size: width of the projection and of the LSTM state.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, in_features=180, hidden_size=128, num_layers=2):
        super().__init__()
        # Fully connected + batch normalisation + ReLU.
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

    def forward(self, x):
        """Map ``[n, c, h, w]`` images to ``[n, hidden_size]`` codes.

        The last axis of *x* must be the image width; it is read from the
        input instead of being fixed at 240.
        """
        # Read the sizes from the layers and the input rather than from
        # constants, so non-default configurations and widths work too.
        features = self.mlp[0].in_features
        steps = x.shape[-1]
        # [n, c, h, w] -> [n, c*h, w]  (the captcha is horizontal, so it is
        # sliced into vertical strips)
        x = x.reshape(-1, features, steps)
        # [n, c*h, w] -> [n, w, c*h]
        x = x.permute(0, 2, 1)
        # [n, w, c*h] -> [n*w, c*h]  (each column is one input vector)
        x = x.reshape(-1, features)
        out = self.mlp(x)
        # [n*w, hidden] -> [n, w, hidden]  (w time steps per image)
        out = out.reshape(-1, steps, out.shape[-1])
        out, _ = self.lstm(out)
        # [n, w, hidden] -> [n, hidden]  (keep only the last time step)
        return out[:, -1, :]
class Decoder(nn.Module):
    """Decode one feature vector into a fixed-length sequence of class scores.

    The vector is repeated ``seq_len`` times, run through an LSTM, and every
    step is projected to ``num_classes`` softmax probabilities.

    Args:
        hidden_size: size of the incoming vector and of the LSTM state.
        num_layers: number of stacked LSTM layers.
        seq_len: number of output positions (4 captcha characters).
        num_classes: classes per position (10 digits).
    """

    # Class-level fallback: instances unpickled from a checkpoint written
    # before ``seq_len`` became an instance attribute still resolve it.
    seq_len = 4

    def __init__(self, hidden_size=128, num_layers=2, seq_len=4, num_classes=10):
        super().__init__()
        self.seq_len = seq_len
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        # Output layer: one score per class for every position.
        self.mlp = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Map ``[n, hidden_size]`` codes to ``[n, seq_len, num_classes]``."""
        hidden = self.lstm.input_size
        # [n, hidden] -> [n, 1, hidden]
        x = x.reshape(-1, 1, hidden)
        # [n, 1, hidden] -> [n, seq_len, hidden]  (same vector at every step)
        x = x.expand(-1, self.seq_len, hidden)
        out, _ = self.lstm(x)
        # Linear acts on the last axis, so no flatten/unflatten is needed:
        # [n, seq_len, hidden] -> [n, seq_len, num_classes]
        out = self.mlp(out)
        # One probability distribution per position.
        return f.softmax(out, dim=2)
class MyNet(nn.Module):
    """Complete Seq2Seq network: an Encoder followed by a Decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode *x*, then decode the resulting code and return the output."""
        return self.decoder(self.encoder(x))

训练

import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from dataset import MyDataset
from net import MyNet

batch_size = 100
net_path = r"modules/mynet.pth"
is_train = True

# Dataset locations
train_path = r"data/train_dataset"
test_path = r"data/test_dataset"


def _build_dataloader():
    """Return the loader for the split selected by ``is_train``."""
    if is_train:
        return DataLoader(MyDataset(train_path), batch_size, shuffle=True, num_workers=4)
    return DataLoader(MyDataset(test_path), batch_size, shuffle=False)


def _load_net():
    """Return a MyNet, restored from ``net_path`` when that file exists."""
    net = MyNet()
    if os.path.isfile(net_path):
        saved = torch.load(net_path)
        # Checkpoints written by earlier versions of this script hold the
        # whole pickled module; current ones hold only the state dict.
        if isinstance(saved, nn.Module):
            return saved
        net.load_state_dict(saved)
    return net


def _train(net, dataloader):
    """Train until interrupted, writing a checkpoint after every epoch."""
    opt = torch.optim.Adam([
        {"params": net.encoder.parameters()},
        {"params": net.decoder.parameters()},
    ])
    loss_fn = nn.MSELoss()

    # torch.save does not create missing directories.
    checkpoint_dir = os.path.dirname(net_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    net.train()
    while True:
        for i, (x, y) in enumerate(dataloader):
            out = net(x)
            loss = loss_fn(out, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # Outputs and labels are one-hot per position; take the argmax.
            result = torch.argmax(out, 2).numpy()
            label = torch.argmax(y, 2).numpy()
            # A sample counts as correct only if all positions match.
            acc = np.mean(np.all(result == label, axis=1))
            print("i:{},loss:{:.5},acc:{:.3}".format(i, loss.item(), acc))

        # Save only the weights: whole-module pickles are tied to the class
        # source and are rejected by torch.load when weights_only=True.
        torch.save(net.state_dict(), net_path)


def _evaluate(net, dataloader):
    """Print prediction and label for the first sample of every batch."""
    net.eval()
    # No gradients are needed for inference.
    with torch.no_grad():
        for x, y in dataloader:
            out = net(x)
            result = torch.argmax(out[0], 1)
            print("result:{}".format(result))
            label = torch.argmax(y[0], 1)
            print("label:{}".format(label))


if __name__ == '__main__':
    # Build the loader inside the guard so DataLoader worker processes that
    # re-import this module do not construct the dataset again.
    dataloader = _build_dataloader()
    net = _load_net()
    if is_train:
        _train(net, dataloader)
    else:
        _evaluate(net, dataloader)

更多推荐

网络模型(Seq2Seq

本文发布于:2024-02-17 09:14:19,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1693487.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:模型   网络   Seq2Seq

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!