Training YOLOX on Your Own Dataset




Training YOLOX on your own dataset under Windows 10

  • Table of Contents
    • 1. Paper and source code
    • 2. Data preparation (VOC format)
    • 3. Configuration changes
    • 4. Training
    • 5. Running detection with your own model
      • Detection results

The YOLO family has long been a popular choice for object detection thanks to its flexibility. Recently, researchers at Megvii proposed YOLOX, a high-performance detector that brings the anchor-free idea into YOLO, making it the second YOLO-series algorithm after YOLOv1 to go anchor-free.
Beyond the anchor-free design, I consider YOLOX's most important innovation to be the decoupled head, which goes some way toward resolving the conflict between the classification and regression tasks; its label-assignment strategy, SimOTA, is another highlight.

1. Paper and source code

1. Paper: https://arxiv.org/abs/2107.08430
2. Source code: https://github.com/Megvii-BaseDetection/YOLOX

2. Data preparation (VOC format)

The YOLOX authors train with data in COCO format; here I use the VOC format instead to train my own model.

2.1 VOC dataset layout
The Annotations folder holds the .xml label files; the JPEGImages folder holds the training images; and the Main folder under ImageSets holds the .txt files listing the training, validation, and test splits.
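A standard VOCdevkit tree looks like this (the VOC2007 naming matches the paths the exp file expects later on):

VOCdevkit
└── VOC2007
    ├── Annotations      # .xml label files
    ├── JPEGImages       # training images
    └── ImageSets
        └── Main         # train.txt / val.txt / trainval.txt / test.txt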

test.txt, train.txt, trainval.txt and val.txt each list the image names, one per line, without file extensions.
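For example, with images named 000001.jpg, 000002.jpg and 000005.jpg (hypothetical names), train.txt would contain:

000001
000002
000005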

These .txt files are required when training YOLOX in VOC format; the script below generates them:

import os
import random

random.seed(0)

xmlfilepath = 'path to your VOC Annotations folder'      # .xml label files; change to your own dataset path
saveBasePath = 'path to your VOC ImageSets/Main folder'  # where test.txt, train.txt, trainval.txt, val.txt are written

# ----------------------------------------------------------------------#
#   Adjust trainval_percent and train_percent to suit your own split
# ----------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 1

temp_xml = os.listdir(xmlfilepath)
total_xml = []
for xml in temp_xml:
    if xml.endswith(".xml"):
        total_xml.append(xml)

num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

print("train and val size", tv)
print("train size", tr)

ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')

for i in indices:
    name = total_xml[i][:-4] + '\n'   # strip the .xml extension
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()

3. Configuration changes

3.1 Changing the detection classes
Edit the class names to match your own task. My project is face detection with a single class, face, so VOC_CLASSES contains only "face". The class names live in D:\YOLOX-main\yolox\data\datasets\voc_classes.py; modify only this file. Separate the class names with commas, and keep a comma after the last class as well.
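For a single face class, voc_classes.py reduces to the following (note the trailing comma, which keeps the one-element tuple a tuple rather than a plain string):

VOC_CLASSES = (
    "face",
)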


3.2 Adjusting training parameters
Edit the relevant parameters in D:\YOLOX-main\yolox\exp\yolox_base.py as needed. I only changed self.num_classes = 1; the other parameters can be tuned to your own requirements.

import os
import random

import torch
import torch.distributed as dist
import torch.nn as nn

from .base_exp import BaseExp


class Exp(BaseExp):
    def __init__(self):
        super().__init__()

        # ---------------- model config ---------------- #
        self.num_classes = 1
        self.depth = 1.00
        self.width = 1.00

        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        self.data_num_workers = 4
        self.input_size = (640, 640)
        self.random_size = (14, 26)
        self.data_dir = None
        self.train_ann = "instances_train2017.json"
        self.val_ann = "instances_val2017.json"

        # --------------- transform config ----------------- #
        self.degrees = 10.0
        self.translate = 0.1
        self.scale = (0.1, 2)
        self.mscale = (0.8, 1.6)
        self.shear = 2.0
        self.perspective = 0.0
        self.enable_mixup = True

        # --------------  training config --------------------- #
        self.warmup_epochs = 5
        self.max_epoch = 200
        self.warmup_lr = 0
        self.basic_lr_per_img = 0.01 / 64.0
        self.scheduler = "yoloxwarmcos"
        self.no_aug_epochs = 15
        self.min_lr_ratio = 0.05
        self.ema = True
        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.print_interval = 10
        self.eval_interval = 10
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # -----------------  testing config ------------------ #
        self.test_size = (640, 640)
        self.test_conf = 0.01
        self.nmsthre = 0.65

    def get_model(self):
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        from yolox.data import (
            COCODataset,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        dataset = COCODataset(
            data_dir=self.data_dir,
            json_file=self.train_ann,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def random_resize(self, data_loader, epoch, rank, is_distributed):
        tensor = torch.LongTensor(2).cuda()

        if rank == 0:
            size_factor = self.input_size[1] * 1.0 / self.input_size[0]
            size = random.randint(*self.random_size)
            size = (int(32 * size), 32 * int(size * size_factor))
            tensor[0] = size[0]
            tensor[1] = size[1]

        if is_distributed:
            dist.barrier()
            dist.broadcast(tensor, 0)

        input_size = data_loader.change_input_dim(
            multiple=(tensor[0].item(), tensor[1].item()), random_range=None
        )
        return input_size

    def get_optimizer(self, batch_size):
        if "optimizer" not in self.__dict__:
            if self.warmup_epochs > 0:
                lr = self.warmup_lr
            else:
                lr = self.basic_lr_per_img * batch_size

            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups

            for k, v in self.model.named_modules():
                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                    pg2.append(v.bias)  # biases
                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
                    pg0.append(v.weight)  # no decay
                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                    pg1.append(v.weight)  # apply decay

            optimizer = torch.optim.SGD(pg0, lr=lr, momentum=self.momentum, nesterov=True)
            optimizer.add_param_group({"params": pg1, "weight_decay": self.weight_decay})  # add pg1 with weight_decay
            optimizer.add_param_group({"params": pg2})
            self.optimizer = optimizer

        return self.optimizer

    def get_lr_scheduler(self, lr, iters_per_epoch):
        from yolox.utils import LRScheduler

        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
        )
        return scheduler

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        from yolox.data import COCODataset, ValTransform

        valdataset = COCODataset(
            data_dir=self.data_dir,
            json_file=self.val_ann if not testdev else "image_info_test-dev2017.json",
            name="val2017" if not testdev else "test2017",
            img_size=self.test_size,
            preproc=ValTransform(rgb_means=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        from yolox.evaluators import COCOEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = COCOEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
            testdev=testdev,
        )
        return evaluator

    def eval(self, model, evaluator, is_distributed, half=False):
        return evaluator.evaluate(model, is_distributed, half)

3.3 Pointing the exp file at your training data
The YOLOX authors provide a VOC exp file, D:\YOLOX-main\exps\example\yolox_voc\yolox_voc_s.py, but it needs two changes:
(1) set self.num_classes = 1;
(2) in get_data_loader, change the VOCDetection argument image_sets=[('2007', 'trainval'), ('2012', 'trainval')] to image_sets=[('2007', 'train')].
The first change depends on your own class count; the second simply matches the split files I generate. If your VOC dataset already contains a trainval.txt file, you can leave that line of the source untouched.
Note: some people see AP=0 during training. This happens when the test.txt behind image_sets=[('2007', 'test')] in get_eval_loader is empty, so make sure test.txt actually has content when preparing the data in step 2. The modified code is below:

# encoding: utf-8
import os

import torch
import torch.distributed as dist

from yolox.data import get_yolox_datadir
from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        self.num_classes = 1
        self.depth = 0.33
        self.width = 0.50
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        from yolox.data import (
            VOCDetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
        )

        dataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            image_sets=[('2007', 'train')],
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=50,
            ),
        )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
                max_labels=120,
            ),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)

        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False):
        from yolox.data import VOCDetection, ValTransform

        valdataset = VOCDetection(
            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
            image_sets=[('2007', 'test')],
            img_size=self.test_size,
            preproc=ValTransform(
                rgb_means=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)

        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False):
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev=testdev)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator
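To guard against the AP=0 problem described above, a quick sanity check before training confirms that test.txt is non-empty (the path below is illustrative; point it at your own VOCdevkit location):

# illustrative path - adjust to your own dataset location
test_txt = r"D:\YOLOX-main\datasets\VOCdevkit\VOC2007\ImageSets\Main\test.txt"
with open(test_txt) as f:
    names = [line.strip() for line in f if line.strip()]
print(len(names), "images listed in test.txt")
assert names, "test.txt is empty - regenerate the splits in step 2, or AP will be 0"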

3.4 Fixing the annotation path format in voc.py
In the _do_python_eval function of D:\YOLOX-main\yolox\data\datasets\voc.py, change annopath = os.path.join(rootpath, "Annotations", "{:s}.xml") to annopath = os.path.join(rootpath, "Annotations", "{}.xml").
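The difference matters because the "{:s}" format code only accepts genuine str objects, while plain "{}" formats anything; presumably some image ids arrive as non-string types during evaluation. A quick illustration:

print("{:s}.xml".format("000001"))  # -> 000001.xml
print("{}.xml".format("000001"))    # -> 000001.xml
# print("{:s}.xml".format(123))     # ValueError: Unknown format code 's' for object of type 'int'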

4. Training

There are two ways to train:
(1) edit the hyperparameters directly in train.py;
(2) train from the terminal, as described in the README of the source repository (see the example command below).
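For reference, a terminal invocation for option (2) would look roughly like this, reusing the flags defined in train.py (paths and batch size are illustrative; check the README for the exact form):

python tools/train.py -f exps/example/yolox_voc/yolox_voc_s.py -d 1 -b 16 --fp16 -c yolox_s.pth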
I chose the first approach; the code below shows my version, with a comment at each place I changed.

import argparse
import random
import warnings

from loguru import logger

import torch
import torch.backends.cudnn as cudnn

from yolox.core import Trainer, launch
from yolox.exp import get_exp
from yolox.utils import configure_nccl, configure_omp


def make_parser():
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    # name of the model variant to train: yolox-s pairs with the yolox-s weights; yolox-l etc. also work
    parser.add_argument("-n", "--name", type=str, default='yolox-s', help="model name")

    # distributed
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    # set batch size according to your GPU memory
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="batch size")
    # I have a single GPU, so devices is set to 0
    parser.add_argument("-d", "--devices", default=0, type=int, help="device for training")
    parser.add_argument(
        "-f",
        "--exp_file",
        # absolute path to the exp file for your training dataset
        default='D:\\YOLOX-main\\exps\\example\\yolox_voc\\yolox_voc_s.py',
        type=str,
        help="plz input your expriment description file",
    )
    parser.add_argument("--resume", default=False, action="store_true", help="resume training")
    # pretrained weights
    parser.add_argument("-c", "--ckpt", default='D:\\YOLOX-main\\yolox_s.pth', type=str, help="checkpoint file")
    parser.add_argument(
        "-e",
        "--start_epoch",
        default=0,
        type=int,
        help="resume training start epoch",
    )
    parser.add_argument("--num_machines", default=1, type=int, help="num of node for training")
    parser.add_argument("--machine_rank", default=0, type=int, help="node rank for multi-node training")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mix precision training.",
    )
    parser.add_argument(
        "-o",
        "--occupy",
        dest="occupy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


@logger.catch
def main(exp, args):
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # set environment variables for distributed training
    configure_nccl()
    configure_omp()
    cudnn.benchmark = True

    trainer = Trainer(exp, args)
    trainer.train()


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count()

    dist_url = "auto" if args.dist_url is None else args.dist_url
    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=dist_url,
        args=(exp, args),
    )

The yolox-s pretrained weights are on my Baidu Netdisk (extraction code: 0r71).

5. Running detection with your own model

Detection uses demo.py; take care to reference the VOC class list that was used for training.
Change 1

from yolox.data.datasets import voc_classes  # training used the VOC class list, so the demo must import the matching classes

Change 2

predictor = Predictor(model, exp, voc_classes.VOC_CLASSES, trt_file, decoder, args.device)

Change 3

cls_names=voc_classes.VOC_CLASSES,

Full modified code

import argparse
import os
import time

from loguru import logger

import cv2

import torch

from yolox.data.data_augment import preproc
from yolox.data.datasets import voc_classes
from yolox.exp import get_exp
from yolox.utils import fuse_model, get_model_info, postprocess, vis

IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]


def make_parser():
    parser = argparse.ArgumentParser("YOLOX Demo!")
    parser.add_argument("--demo", default="image", help="demo type, eg. image, video and webcam")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default='yolox-s', help="model name")
    parser.add_argument("--path", default="path to the images to detect", help="path to images or video")
    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
    parser.add_argument(
        "--save_result",
        action="store_true",
        default=True,  # set the default to True so detection results are saved
        help="whether to save the inference result of image/video",
    )

    # exp file
    parser.add_argument(
        "-f",
        "--exp_file",
        default='D:\\YOLOX-main\\exps\\default\\yolox_s.py',
        type=str,
        help="pls input your expriment description file",
    )
    parser.add_argument("-c", "--ckpt", default='path to your trained model', type=str, help="ckpt for eval")
    parser.add_argument(
        "--device",
        default="gpu",
        type=str,
        help="device to run our model, can either be cpu or gpu",
    )
    # ------------------------------------------------------------------#
    # adjust the thresholds below to suit your own project
    parser.add_argument("--conf", default=0.35, type=float, help="test conf")
    parser.add_argument("--nms", default=0.5, type=float, help="test nms threshold")
    parser.add_argument("--tsize", default=640, type=int, help="test img size")
    parser.add_argument(
        "--fp16",
        dest="fp16",
        default=True,
        action="store_true",
        help="Adopting mix precision evaluating.",
    )
    parser.add_argument(
        "--fuse",
        dest="fuse",
        default=False,
        action="store_true",
        help="Fuse conv and bn for testing.",
    )
    parser.add_argument(
        "--trt",
        dest="trt",
        default=False,
        action="store_true",
        help="Using TensorRT model for testing.",
    )
    return parser


def get_image_list(path):
    image_names = []
    for maindir, subdir, file_name_list in os.walk(path):
        for filename in file_name_list:
            apath = os.path.join(maindir, filename)
            ext = os.path.splitext(apath)[1]
            if ext in IMAGE_EXT:
                image_names.append(apath)
    return image_names


class Predictor(object):
    def __init__(
        self,
        model,
        exp,
        cls_names=voc_classes.VOC_CLASSES,
        trt_file=None,
        decoder=None,
        device="gpu",
    ):
        self.model = model
        self.cls_names = cls_names
        self.decoder = decoder
        self.num_classes = exp.num_classes
        self.confthre = exp.test_conf
        self.nmsthre = exp.nmsthre
        self.test_size = exp.test_size
        self.device = device
        if trt_file is not None:
            from torch2trt import TRTModule

            model_trt = TRTModule()
            model_trt.load_state_dict(torch.load(trt_file))

            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
            self.model(x)
            self.model = model_trt
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = torch.from_numpy(img).unsqueeze(0)
        if self.device == "gpu":
            img = img.cuda()

        with torch.no_grad():
            t0 = time.time()
            outputs = self.model(img)
            if self.decoder is not None:
                outputs = self.decoder(outputs, dtype=outputs.type())
            outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info

    def visual(self, output, img_info, cls_conf=0.35):
        ratio = img_info["ratio"]
        img = img_info["raw_img"]
        if output is None:
            return img
        output = output.cpu()

        bboxes = output[:, 0:4]
        # preprocessing: resize
        bboxes /= ratio

        cls = output[:, 6]
        scores = output[:, 4] * output[:, 5]

        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
        return vis_res


def image_demo(predictor, vis_folder, path, current_time, save_result):
    if os.path.isdir(path):
        files = get_image_list(path)
    else:
        files = [path]
    files.sort()
    for image_name in files:
        outputs, img_info = predictor.inference(image_name)
        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
        if save_result:
            save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
            os.makedirs(save_folder, exist_ok=True)
            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
            logger.info("Saving detection result in {}".format(save_file_name))
            cv2.imwrite(save_file_name, result_image)
        ch = cv2.waitKey(0)
        if ch == 27 or ch == ord("q") or ch == ord("Q"):
            break


def imageflow_demo(predictor, vis_folder, current_time, args):
    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    fps = cap.get(cv2.CAP_PROP_FPS)
    save_folder = os.path.join(vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
    os.makedirs(save_folder, exist_ok=True)
    if args.demo == "video":
        save_path = os.path.join(save_folder, args.path.split("/")[-1])
    else:
        save_path = os.path.join(save_folder, "camera.mp4")
    logger.info(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            outputs, img_info = predictor.inference(frame)
            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
            if args.save_result:
                vid_writer.write(result_frame)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
        else:
            break


def main(exp, args):
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    file_name = os.path.join(exp.output_dir, args.experiment_name)
    os.makedirs(file_name, exist_ok=True)

    if args.save_result:
        vis_folder = os.path.join(file_name, "vis_res")
        os.makedirs(vis_folder, exist_ok=True)

    if args.trt:
        args.device = "gpu"

    logger.info("Args: {}".format(args))

    if args.conf is not None:
        exp.test_conf = args.conf
    if args.nms is not None:
        exp.nmsthre = args.nms
    if args.tsize is not None:
        exp.test_size = (args.tsize, args.tsize)

    model = exp.get_model()
    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))

    if args.device == "gpu":
        model.cuda()
    model.eval()

    if not args.trt:
        if args.ckpt is None:
            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
        else:
            ckpt_file = args.ckpt
        logger.info("loading checkpoint")
        ckpt = torch.load(ckpt_file, map_location="cpu")
        # load the model state dict
        model.load_state_dict(ckpt["model"])
        logger.info("loaded checkpoint done.")

    if args.fuse:
        logger.info("\tFusing model...")
        model = fuse_model(model)

    if args.trt:
        assert not args.fuse, "TensorRT model is not support model fusing!"
        trt_file = os.path.join(file_name, "model_trt.pth")
        assert os.path.exists(trt_file), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
        model.head.decode_in_inference = False
        decoder = model.head.decode_outputs
        logger.info("Using TensorRT to inference")
    else:
        trt_file = None
        decoder = None

    predictor = Predictor(model, exp, voc_classes.VOC_CLASSES, trt_file, decoder, args.device)
    current_time = time.localtime()
    if args.demo == "image":
        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
    elif args.demo == "video" or args.demo == "webcam":
        imageflow_demo(predictor, vis_folder, current_time, args)


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    main(exp, args)
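With the defaults baked into make_parser above, running python demo.py with no arguments is enough; the equivalent explicit call would look roughly like this (the image and checkpoint paths are illustrative):

python demo.py --demo image --path D:\test_images -c D:\YOLOX-main\YOLOX_outputs\yolox_voc_s\best_ckpt.pth --save_result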

Detection results

If you spot any mistakes, message me directly; if anything is unclear, ask in the comments. The modified code will be uploaded to GitHub later.
