Training YOLOX on a Damage Dataset (for internal reference only)

All files with code modifications are listed below, so they can be copied and used directly.

1、YOLOX-main\exps\example\yolox_voc\yolox_voc_damage_m.py

# encoding: utf-8
import os
import random
import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.exp import Exp as MyExp
from yolox.data import get_yolox_datadir


# damage dataset
class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.num_classes = 3
        self.depth = 0.67
        self.width = 0.75
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        from yolox.data import (
            VOCDetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        # damage dataset
        dataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            # image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
            train_or_test_txt="cansun_train.txt",
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        from yolox.data import VOCDetection, ValTransform

        valdataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            image_sets=[('2007', 'test')],
            train_or_test_txt="cansun_test.txt",
            img_size=self.test_size,
            preproc=ValTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator
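
With this exp file in place, training can be launched with YOLOX's standard training script, for example (the device count, batch size, and pretrained-weight path below are assumptions, adjust them to your environment):

python tools/train.py -f exps/example/yolox_voc/yolox_voc_damage_m.py -d 1 -b 8 --fp16 -c yolox_m.pth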

2、YOLOX-main/yolox/data/datasets/voc_classes.py, this file defines the class names (the actual class names have been anonymized here; see the txt file for details)

# damage dataset
VOC_CLASSES = (
    "damage0",
    "damage1",
    "damage2",
)
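
Note that num_classes = 3 in yolox_voc_damage_m.py must stay consistent with the length of this tuple. A minimal sanity check (not part of the original code):

from yolox.data.datasets.voc_classes import VOC_CLASSES

# keep the class list and the exp's num_classes in sync
assert len(VOC_CLASSES) == 3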

3、YOLOX-main/yolox/data/datasets/voc.py

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code are based on
# .py
# Copyright (c) Francisco Massa.
# Copyright (c) Ellis Brown, Max deGroot.
# Copyright (c) Megvii, Inc. and its affiliates.

import os
import os.path
import pickle
import xml.etree.ElementTree as ET

import cv2
import numpy as np

from yolox.evaluators.voc_eval import voc_eval

from .datasets_wrapper import Dataset
from .voc_classes import VOC_CLASSES


class AnnotationTransform(object):
    """Transforms a VOC annotation into a Tensor of bbox coords and label index
    Initilized with a dictionary lookup of classnames to indexes

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
        height (int): height
        width (int): width
    """

    def __init__(self, class_to_ind=None, keep_difficult=True):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES)))
        )
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = np.empty((0, 5))
        for obj in target.iter("object"):
            difficult = int(obj.find("difficult").text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.lower().strip()
            bbox = obj.find("bndbox")

            pts = ["xmin", "ymin", "xmax", "ymax"]
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # scale height or width
                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
            # img_id = target.find('filename').text[:-4]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]


class VOCDetection(Dataset):
    """
    VOC Detection Dataset Object

    input is image, target is annotation

    Args:
        root (string): filepath to VOCdevkit folder.
        image_set (string): imageset to use (eg. 'train', 'val', 'test')
        transform (callable, optional): transformation to perform on the
            input image
        target_transform (callable, optional): transformation to perform on the
            target `annotation`
            (eg: take in caption string, return tensor of word indices)
        dataset_name (string, optional): which dataset to load
            (default: 'VOC2007')
    """

    def __init__(
        self,
        data_dir,
        image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
        train_or_test_txt="cansun_train.txt",
        img_size=(416, 416),
        preproc=None,
        target_transform=AnnotationTransform(),
        # dataset_name="VOC0712",
        dataset_name="damage",
    ):
        super().__init__(img_size)
        self.root = data_dir
        self.image_set = image_sets
        self.train_or_test_txt = train_or_test_txt
        self.img_size = img_size
        self.preproc = preproc
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
        self._classes = VOC_CLASSES
        self.ids = list()
        '''
        for (year, name) in image_sets:
            self._year = year
            rootpath = os.path.join(self.root, "VOC" + year)
            for line in open(
                os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
            ):
                self.ids.append((rootpath, line.strip()))
        '''
        # damage dataset
        rootpath = os.path.join(self.root, self.name)
        for line in open(os.path.join(rootpath, self.train_or_test_txt)):
            self.ids.append((rootpath, line.strip()))

    def __len__(self):
        return len(self.ids)

    def load_anno(self, index):
        img_id = self.ids[index]
        target = ET.parse(self._annopath % img_id).getroot()
        if self.target_transform is not None:
            target = self.target_transform(target)

        return target

    def pull_item(self, index):
        """Returns the original image and target at an index for mixup

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            img, target
        """
        img_id = self.ids[index]
        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        height, width, _ = img.shape

        target = self.load_anno(index)

        img_info = (width, height)

        return img, target, img_info, index

    @Dataset.resize_getitem
    def __getitem__(self, index):
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)

        return img, target, img_info, img_id

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        IouTh = np.linspace(
            0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
        )
        mAPs = []
        for iou in IouTh:
            mAP = self._do_python_eval(output_dir, iou)
            mAPs.append(mAP)

        print("--------------------------------------------------------------")
        print("map_5095:", np.mean(mAPs))
        print("map_50:", mAPs[0])
        print("--------------------------------------------------------------")
        return np.mean(mAPs), mAPs[0]

    def _get_voc_results_file_template(self):
        filename = "comp4_det_test" + "_{:s}.txt"
        # filedir = os.path.join(self.root, "results", "VOC" + self._year, "Main")
        filedir = os.path.join(self.root, "results", "VOC", "Main")
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(VOC_CLASSES):
            cls_ind = cls_ind
            if cls == "__background__":
                continue
            print("Writing {} VOC results file".format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, "wt") as f:
                for im_ind, index in enumerate(self.ids):
                    index = index[1]
                    dets = all_boxes[cls_ind][im_ind]
                    if dets == []:
                        continue
                    for k in range(dets.shape[0]):
                        f.write(
                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
                                index,
                                dets[k, -1],
                                dets[k, 0] + 1,
                                dets[k, 1] + 1,
                                dets[k, 2] + 1,
                                dets[k, 3] + 1,
                            )
                        )

    def _do_python_eval(self, output_dir="output", iou=0.5):
        # rootpath = os.path.join(self.root, "VOC" + self._year)
        name = self.image_set[0][1]
        # annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
        annopath = os.path.join(self.root, "damage", "Annotations", "{:s}.xml")
        # imagesetfile = os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
        imagesetfile = os.path.join(self.root, "damage", self.train_or_test_txt)
        # cachedir = os.path.join(
        #     self.root, "annotations_cache", "VOC" + self._year, name
        # )
        cachedir = os.path.join(self.root, "annotations_cache", "VOC", name)
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)
        aps = []
        # The PASCAL VOC metric changed in 2010
        # use_07_metric = True if int(self._year) < 2010 else False
        use_07_metric = True
        print("Eval IoU : {:.2f}".format(iou))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(VOC_CLASSES):
            if cls == "__background__":
                continue

            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename,
                annopath,
                imagesetfile,
                cls,
                cachedir,
                ovthresh=iou,
                use_07_metric=use_07_metric,
            )
            aps += [ap]
            if iou == 0.5:
                print("AP for {} = {:.4f}".format(cls, ap))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
        if iou == 0.5:
            print("Mean AP = {:.4f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("Results:")
            for ap in aps:
                print("{:.3f}".format(ap))
            print("{:.3f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("")
            print("--------------------------------------------------------------")
            print("Results computed with the **unofficial** Python eval code.")
            print("Results should be very close to the official MATLAB eval code.")
            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
            print("-- Thanks, The Management")
            print("--------------------------------------------------------------")

        return np.mean(aps)
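
The modified VOCDetection above reads image ids from datasets/VOCdevkit/damage/cansun_train.txt (or cansun_test.txt), one file stem per line, and resolves each id against the JPEGImages/ and Annotations/ folders next to the txt files. A minimal sketch for generating the two lists follows; the 9:1 split ratio and the fixed random seed are assumptions, not part of the original post:

import os
import random

root = "datasets/VOCdevkit/damage"  # dataset root implied by the code above
ids = sorted(
    os.path.splitext(f)[0]
    for f in os.listdir(os.path.join(root, "JPEGImages"))
    if f.lower().endswith(".jpg")
)
random.seed(0)
random.shuffle(ids)
split = int(len(ids) * 0.9)  # assumed 9:1 train/test split

with open(os.path.join(root, "cansun_train.txt"), "w") as f:
    f.write("\n".join(ids[:split]) + "\n")
with open(os.path.join(root, "cansun_test.txt"), "w") as f:
    f.write("\n".join(ids[split:]) + "\n")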

4、For inference, switch the dataset from COCO to VOC: YOLOX-main/yolox/data/datasets/__init__.py

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from .coco import COCODataset
from .coco_classes import COCO_CLASSES
from .voc_classes import VOC_CLASSES
from .datasets_wrapper import ConcatDataset, Dataset, MixConcatDataset
from .mosaicdetection import MosaicDetection
from .voc import VOCDetection    # newly added

5、YOLOX-main/tools/demo.py

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

import argparse
import os
import time
from loguru import logger

import cv2

import torch

from yolox.data.data_augment import preproc
from yolox.data.datasets import VOC_CLASSES
#from yolox.data.datasets import VOC_CLASSES
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument('demo', default='image', help='demo type, eg. image, video and webcam')
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument('--path', default='./assets/dog.jpg', help='path to images or video')
    parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
    parser.add_argument(
        '--save_result', action='store_true',
        help='whether to save the inference result of image/video'
    )
    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="pls input your expriment description file",
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
    parser.add_argument("--device", default="cpu", type=str, help="device to run our model, can either be cpu or gpu")
    parser.add_argument("--conf", default=None, type=float, help="test conf")
    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=None, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=False,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    # def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu"):
    def __init__(self, model, exp, cls_names=VOC_CLASSES, trt_file=None, decoder=None, device="cpu"):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        if trt_file is not None:
            from torch2trt import TRTModule
            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        img_info = {'id': 0}
        if isinstance(img, str):
            img_info['file_name'] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info['file_name'] = None

        height, width = img.shape[:2]
        img_info['height'] = height
        img_info['width'] = width
        img_info['raw_img'] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info['ratio'] = ratio
        img = torch.from_numpy(img).unsqueeze(0)
        if self.device == "gpu":
            img = img.cuda()

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
            logger.info('Infer time: {:.4f}s'.format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info['ratio']
        img = img_info['raw_img']
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]

        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(
                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
            )
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord('q') or ch == ord('Q'):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(
        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
    )
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split('/')[-1])
    else:
        save_path = os.path.join(save_folder, 'camera.mp4')
    logger.info(f'video save_path is {save_path}')
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord('q') or ch == ord('Q'):
                break
        else:
            break


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    if args.save_result:
        vis_folder = os.path.join(file_name, 'vis_res')
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth.tar")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(trt_file), (
            "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        )
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    # predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
    predictor = Predictor(model, exp, VOC_CLASSES, trt_file, decoder, args.device)
    current_time = time.localtime()
    if args.demo == 'image':
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == 'video' or args.demo == 'webcam':
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)

    main(exp, args)
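
After training, single-image inference can be run with this demo script; the flags map directly to make_parser() above. The checkpoint and image paths below are placeholders, adjust them to your own output and test data:

python tools/demo.py image -f exps/example/yolox_voc/yolox_voc_damage_m.py -c YOLOX_outputs/yolox_voc_damage_m/best_ckpt.pth.tar --path assets/test.jpg --conf 0.3 --nms 0.45 --tsize 640 --save_result --device gpu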
