admin管理员组文章数量:1609966
文章目录
- 数据集
- 数据集处理
- 迁移学习网络
- 原理
- 代码实现
数据集
使用宝可梦精灵的图片数据集。数据集地址:
- 链接:https://pan.baidu/s/1zDERMsV1AvwfZudhuae6Ew
- 提取码:rs4h
数据集中的每一类别的图片放在一个文件夹中
数据集共包含5个类别的图片,我们取每个文件夹(类别):
- 前60%做训练集
- 60%~80%做验证集
- 80%~100%做测试集
数据集处理
'''
load图片数据集
'''
import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
class Pokemon(Dataset):
def __init__(self, root, resize, mode):
'''
:param root: 数据集目录
:param resize: 图片的输出size
:param mode: train/val/test
'''
super(Pokemon, self).__init__()
self.root = root # 根目录
self.resize = resize # 图片的输出size
self.name2label = {} # 对目录名(类别)进行编码
for name in sorted(os.listdir(os.path.join(root))): # 遍历目录和文件
if not os.path.isdir(os.path.join(root, name)): # 如果不是目录(是图片)
continue
self.name2label[name] = len(self.name2label.keys()) # 用字典保存类别的编码
# print(self.name2label)
'''读入图片数据集'''
# image, label
self.images, self.labels = self.load_csv('images.csv')
'''划分train、val、test集'''
if mode=='train': # train: 60%
self.images = self.images[:int(0.6*len(self.images))]
self.labels = self.labels[:int(0.6*len(self.labels))]
elif mode=='val': # val: 20% = 60%->80%
self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
else: # test: 20% = 80%->100%
self.images = self.images[int(0.8*len(self.images)):]
self.labels = self.labels[int(0.8*len(self.labels)):]
def load_csv(self, filename):
'''
一次加载进所有图片可能会造成内存不够用,因此我们可以把图片保存到一个csv文件
:param filename:保存的文件名
:return:
'''
# 如果csv文件不存在,就创建文件
# 如果csv文件存在,就是之前已经创建过,直接读取就好了
if not os.path.exists(os.path.join(self.root, filename)):
'''把所有的文件放到一个list中去。文件的class可以通过路径名来判定'''
images = []
for name in self.name2label.keys():
# 'pokemon\\mewtwo\\00001.png
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
print(len(images), images) # 1167
random.shuffle(images) # 打乱顺序
'''写入csv文件'''
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images: # 'pokemon\\bulbasaur\\00000000.png'
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img, label])
# 'pokemon\\bulbasaur\\00000000.png', 0
print('writen into csv file:', filename)
'''read from csv file'''
images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
# 'pokemon\\bulbasaur\\00000000.png', 0
img, label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels) # 检查条件,不符合就终止
return images, labels
def __len__(self):
'''
返回总体样本数量
:return:
'''
return len(self.images)
def denormalize(self, x_hat):
'''
逆标准化处理
:param x_hat: 标准化的tensor
:return: 逆标准化的tensor
'''
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x: [channel, high, wight]
# mean: [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
print(mean.shape, std.shape)
x = x_hat * std + mean
return x
def __getitem__(self, idx):
'''
取得当前位置图片
:param idx: 图片索引
:return:
'''
img, label = self.images[idx], self.labels[idx]
'''数据增强之后将图片转换为tensor'''
tf = transforms.Compose([
lambda x:Image.open(x).convert('RGB'), # string path= > image data
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))), # 图片放大1.25倍
transforms.RandomRotation(15), # 随机旋转,在-15° ~ +15°之间
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], # 标准化,这几个数是大范围统计出来的rgb三原色的均值和方差
std=[0.229, 0.224, 0.225])
])
# tf = transforms.Compose([
# lambda x:Image.open(x).convert('RGB'), # string path= > image data
# transforms.Resize((self.resize, self.resize)), # 图片放大1.25倍
# transforms.ToTensor(),
# ])
img = tf(img)
label = torch.tensor(label)
return img, label
def main():
'''
可视化查看数据集
此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''
import visdom
import time
import torchvision
viz = visdom.Visdom()
# 如果图片的存储很标准,可以用这种方法
# tf = transforms.Compose([
# transforms.Resize((64,64)),
# transforms.ToTensor(),
# ])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# loader = DataLoader(db, batch_size=32, shuffle=True)
#
# print(db.class_to_idx)
#
# for x,y in loader:
# viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
#
# time.sleep(10)
# 通用的方法
db = Pokemon('pokemon', 64, 'train')
x,y = next(iter(db))
print('sample:', x.shape, y.shape, y)
# 加载一张图片
viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
# viz.image(x, win='sample_x', opts=dict(title='sample_x'))
# 加载一个batch的图片
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
for x, y in loader:
viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
time.sleep(10)
if __name__ == '__main__':
main()
迁移学习网络
原理
Pokemon和ImageNet都需要图片中提取特征,因此存在某些共性的knowledge。因此我们可以利用更加通用的ImageNet的模型,帮我们解决特定的图片分类任务。
我们采用torchvision.models中训练好的resnet18,使用它训练好的卷积部分提取图像特征,并训练新的分类器处理我们提取到的特征。
这样我们只需要训练分类器,而不用再训练特征提取器,因此可以减少所需训练量。
代码实现
辅助文件:utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
'''
定义一个神经网络层
第一个维度保持,其他维度打平成一个维度
'''
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
'''
把image打印在matplotlab上
'''
def plot_image(img, label, name):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
实现网络构建,网络训练与评估的文件:train_transfer.py
'''
利用迁移学习
torchvision提供了训练好的resnet18、resnet34、resnet50...
此处需要安装并开启visdom
安装:pip install visdom
开启:python -m visdom.server
'''
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
from pokemon import Pokemon
from utils import Flatten
# 引入已经训练好的model
from torchvision.models import resnet18
batchsz = 32
lr = 1e-3
epochs = 10
device = torch.device('cuda')
torch.manual_seed(1234)
train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
# 每次会开启num_work个线程,分别去加载dataset里面的数据,直到每个worker加载数据量为batch_size 大小(总共num_work*batch_size)才会进行下一步训练
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
model.eval()
correct = 0
total = len(loader.dataset)
for x,y in loader:
x,y = x.to(device), y.to(device)
with torch.no_grad(): # 不计算梯度
logits = model(x) # 前向运算
pred = logits.argmax(dim=1) # 选出输出层最大的元素
correct += torch.eq(pred, y).sum().float().item()
return correct / total
def main():
'''初始化网络'''
trained_model = resnet18(pretrained=True) # 已经训练好的model
# x: [b, 3, 224, 224]
model = nn.Sequential(*list(trained_model.children())[:-1], # [b, 3, 224, 224] => [b, 512, 1, 1] # 取出从0到17层,作为特征提取器
Flatten(), # [b, 512, 1, 1] => [b, 512] # 自己定义的类,改变tensor维度
nn.Linear(512, 5) # [b, 512] => [b, 5] # 随机初始化的一个新的线性层,作为分类器
).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
'''记录实验结果参数'''
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
'''训练与评估'''
for epoch in range(epochs):
'''训练一次模型'''
for step, (x, y) in enumerate(train_loader): # 遍历
# x: [b, 3, 224, 224], y: [b]
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
# logits: [b, 5]
# y: [b]
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
'''评估模型'''
if epoch % 1 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
best_acc = val_acc
torch.save(model.state_dict(), 'best.mdl') # 保存评估结果最好的模型
viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best acc:', best_acc, 'best epoch:', best_epoch)
'''加载最优模型'''
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
'''测试模型'''
test_acc = evalute(model, test_loader)
print('test acc:', test_acc)
if __name__ == '__main__':
main()
版权声明:本文标题:pytorch——迁移学习实战宝可梦精灵分类 内容由热心网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:https://www.elefans.com/xitong/1728585734a1164910.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论