PaddlePaddle入门实践——十二生肖分类

编程入门 行业动态 更新时间:2024-10-23 15:30:37

PaddlePaddle入门实践——十二<a href=https://www.elefans.com/category/jswz/34/1766988.html style=生肖分类"/>

PaddlePaddle入门实践——十二生肖分类

十二生肖分类

    • 任务要求
    • 图像分类
      • 实现思路
      • 图像分类原理
    • 数据准备
      • 解压数据集
      • 数据标注
      • 数据集定义
    • 模型开发
    • 模型训练优化
    • 模型评估
    • 参考

任务要求

  找到一个最优算法,让机器能够分清每个属相动物的照片,这是一个基于图像的分类任务

图像分类

实现思路

图像分类原理

数据准备

解压数据集

  我们将网上获取的数据集以压缩包的方式上传到aistudio数据集中,并加载到我们的项目内。在使用之前我们进行数据集压缩包的一个解压(十二生肖数据集

!unzip -q -o <压缩包路径>

数据标注

  数据集结构:

.
├── test
│   ├── dog
│   ├── dragon
│   ├── goat
│   ├── horse
│   ├── monkey
│   ├── ox
│   ├── pig
│   ├── rabbit
│   ├── ratt
│   ├── rooster
│   ├── snake
│   └── tiger
├── train
│   ├── dog
│   ├── dragon
│   ├── goat
│   ├── horse
│   ├── monkey
│   ├── ox
│   ├── pig
│   ├── rabbit
│   ├── ratt
│   ├── rooster
│   ├── snake
│   └── tiger
└── valid├── dog├── dragon├── goat├── horse├── monkey├── ox├── pig├── rabbit├── ratt├── rooster├── snake└── tiger

  数据集分为train、valid、test三个文件夹,每个文件夹内包含12个分类文件夹,每个分类文件夹内是具体的样本图片。我们对这些样本进行一个标注处理,最终生成train.txt/valid.txt/test.txt三个数据标注文件。

import io
import os
from PIL import Image
from config import get# 数据集根目录
DATA_ROOT = 'signs'# 标签List
LABEL_MAP = get('LABEL_MAP')# 标注生成函数
def generate_annotation(mode):# 建立标注文件with open('{}/{}.txt'.format(DATA_ROOT, mode), 'w') as f:# 对应每个用途的数据文件夹,train/valid/testtrain_dir = '{}/{}'.format(DATA_ROOT, mode)# 遍历文件夹,获取里面的分类文件夹for path in os.listdir(train_dir):# 标签对应的数字索引,实际标注的时候直接使用数字索引label_index = LABEL_MAP.index(path)# 图像样本所在的路径image_path = '{}/{}'.format(train_dir, path)# 遍历所有图像for image in os.listdir(image_path):# 图像完整路径和名称image_file = '{}/{}'.format(image_path, image)try:# 验证图片格式是否okwith open(image_file, 'rb') as f_img:image = Image.open(io.BytesIO(f_img.read()))image.load()if image.mode == 'RGB':f.write('{}\t{}\n'.format(image_file, label_index))except:continuegenerate_annotation('train')  # 生成训练集标注文件
generate_annotation('valid')  # 生成验证集标注文件
generate_annotation('test')   # 生成测试集标注文件

数据集定义

  接下来我们使用标注好的文件进行数据集类的定义,方便后续模型训练使用。

import paddle
import numpy as np
from config import get# 导入数据集的定义实现
from dataset import ZodiacDataset# 实例化数据集类
train_dataset = ZodiacDataset(mode='train')
valid_dataset = ZodiacDataset(mode='valid')print('训练数据集:{}张;验证数据集:{}张'.format(len(train_dataset), len(valid_dataset)))

导入数据集dataset.py文件

import paddle
import paddle.vision.transforms as T
import numpy as np
from config import get
from PIL import Image__all__ = ['ZodiacDataset']# 定义图像的大小
image_shape = get('image_shape')
IMAGE_SIZE = (image_shape[1], image_shape[2])class ZodiacDataset(paddle.io.Dataset):"""十二生肖数据集类的定义"""def __init__(self, mode='train'):"""初始化函数"""assert mode in ['train', 'test', 'valid'], 'mode is one of train, test, valid.'self.data = []with open('signs/{}.txt'.format(mode)) as f:for line in f.readlines():info = line.strip().split('\t')if len(info) > 0:self.data.append([info[0].strip(), info[1].strip()])if mode == 'train':self.transforms = T.Compose([T.RandomResizedCrop(IMAGE_SIZE),    # 随机裁剪大小T.RandomHorizontalFlip(0.5),        # 随机水平翻转T.ToTensor(),                       # 数据的格式转换和标准化 HWC => CHW  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 图像归一化])else:self.transforms = T.Compose([T.Resize(256),                 # 图像大小修改T.RandomCrop(IMAGE_SIZE),      # 随机裁剪T.ToTensor(),                  # 数据的格式转换和标准化 HWC => CHWT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # 图像归一化])def __getitem__(self, index):"""根据索引获取单个样本"""image_file, label = self.data[index]image = Image.open(image_file)if image.mode != 'RGB':image = image.convert('RGB')image = self.transforms(image)return image, np.array(label, dtype='int64')def __len__(self):"""获取样本总数"""return len(self.data)

config.py配置项

__all__ = ['CONFIG', 'get']CONFIG = {'model_save_dir': "./output/zodiac",'num_classes': 12,'total_images': 7096,'epochs': 20,'batch_size': 32,'image_shape': [3, 224, 224],'LEARNING_RATE': {'params': {'lr': 0.00375             }},'OPTIMIZER': {'params': {'momentum': 0.9},'regularizer': {'function': 'L2','factor': 0.000001}},'LABEL_MAP': ["ratt","ox","tiger","rabbit","dragon","snake","horse","goat","monkey","rooster","dog","pig",]
}def get(full_path):for id, name in enumerate(full_path.split('.')):if id == 0:config = CONFIGconfig = config[name]return config

模型开发

  采用ResNet50网络构建模型。模型结构比较复杂,但PaddlePaddle高层API内置了paddle.vision.models.resnet50接口,能让我们一行代码生成模型,便于更快体验效果。
  下图为ResNet50结构:

  paddle.vision.models.resnet50内的pretrained参数设置True,目的是加载在imagenet数据集上的预训练权重,即基于前人训练效果很好的参数来迭代训练我们自己的模型。

network = paddle.vision.models.resnet50(num_classes=get('num_classes'), pretrained=True)
model = paddle.Model(network)
model.summary((-1, ) + tuple(get('image_shape')))

模型训练优化

EPOCHS = get('epochs')
BATCH_SIZE = get('batch_size')def create_optim(parameters):step_each_epoch = get('total_images') // get('batch_size')lr = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=get('LEARNING_RATE.params.lr'),T_max=step_each_epoch * EPOCHS)return paddle.optimizer.Momentum(learning_rate=lr,parameters=parameters,weight_decay=paddle.regularizer.L2Decay(get('OPTIMIZER.regularizer.factor')))# 模型训练配置
model.prepare(create_optim(network.parameters()),  # 优化器paddle.nn.CrossEntropyLoss(),        # 损失函数paddle.metric.Accuracy(topk=(1, 5))) # 评估指标# 训练可视化VisualDL可视化工具的回调函数
visualdl = paddle.callbacks.VisualDL(log_dir='visuald_log')#启动模型全流程训练
model.fit(train_dataset,             # 训练数据集valid_dataset,             # 评估数据集epochs=EPOCHS,             # 总的训练轮次batch_size=BATCH_SIZE,     # 批次样本的样本量大小shuffle=True,              # 是否打乱样本集verbose=1,                 # 日志展示格式save_dir='./chk_points/',  # 分阶段的训练模型存储路径callbacks=[visualdl])      # 回调函数


可视化VisualDL

模型评估

批量预测

predict_dataset = ZodiacDataset(mode='test')
print('测试数据集样本量:{}'.format(len(predict_dataset)))from paddle.static import InputSpec# 网络结构
network = paddle.vision.models.resnet50(num_classes=get('num_classes'))# 模型封装
model_2 = paddle.Model(network, inputs=[InputSpec(shape=[-1] + get('image_shape'), dtype='float32', name='image')])# 模型文件加载
model_2.load(get('model_save_dir'))
# 模型配置
model_2.prepare()# 执行预测
result = model_2.predict(predict_dataset)# 样本映射
LABEL_MAP = get('LABEL_MAP')# 随机取样本展示
indexs = [2, 38, 56, 92, 100, 303]for idx in indexs:predict_label = np.argmax(result[0][idx])real_label = predict_dataset[idx][1]print('样本ID:{}, 真实标签:{}, 预测值:{}'.format(idx, LABEL_MAP[real_label], LABEL_MAP[predict_label]))

参考

  • 十二生肖数据集
  • PaddlePaddle深度学习入门项目

更多推荐

PaddlePaddle入门实践——十二生肖分类

本文发布于:2024-02-07 07:21:09,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1755063.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:生肖   入门   PaddlePaddle

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!