Cat vs. dog classification with ResNet in PyTorch

Updated: 2024-10-09 11:27:08

I have been practicing PyTorch lately.

First, download the cats-vs-dogs data:

Link:
Extraction code: 2xq4
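The split step in the script assumes the extracted training folder is flat, with files named cat.<i>.jpg and dog.<i>.jpg, and that both classes are numbered 0 to N-1 (it takes N as half the file count and builds names with animal[:-1] + '.{}.jpg'.format(i)). A small standard-library sketch to check that assumption before copying anything; check_layout is a hypothetical helper, not part of the original post:

```python
import os
import re

# Hypothetical helper (not from the original post): checks the flat
# cat.<i>.jpg / dog.<i>.jpg layout that the split script relies on.
PATTERN = re.compile(r'^(cat|dog)\.(\d+)\.jpg$')


def check_layout(folder):
    """Return N if both classes are numbered exactly 0..N-1, else raise ValueError."""
    indices = {'cat': set(), 'dog': set()}
    for name in os.listdir(folder):
        match = PATTERN.match(name)
        if match is None:
            raise ValueError('unexpected file: ' + name)
        indices[match.group(1)].add(int(match.group(2)))
    n = len(indices['cat'])
    expected = set(range(n))
    if indices['cat'] != expected or indices['dog'] != expected:
        raise ValueError('cat and dog files must both be numbered 0..N-1')
    return n
```

Calling check_layout('/home/wangyunhao/dc/dogs-vs-cats/train/train') returns the per-class count when the assumption holds and raises ValueError otherwise, which is cheaper to find out than a FileNotFoundError halfway through shutil.copyfile.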

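Then write the code. Training this way, reading each batch from the hard disk as it is needed (the script loads with num_workers = 0, i.e. in the main process), feels a bit slow, but let's get it running first. If you are interested, read through the ResNet source code; typing it out yourself is even better practice. A walkthrough of the source: resnet代码分析 - 慢行厚积 - 博客园 ("ResNet code analysis" by 慢行厚积 on cnblogs). Get comfortable with the simple PyTorch interfaces first; the more advanced detection and segmentation tasks come later.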
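For reading the ResNet source, it helps to know the unit that resnet101 is built from: the bottleneck residual block. A minimal sketch written for illustration (a simplified re-implementation, not torchvision's actual code, though the name Bottleneck and putting the stride on the 3x3 conv follow torchvision): a 1x1 conv reduces the channels, a 3x3 conv processes them, a 1x1 conv expands them by a factor of 4, and the block input is added back through an identity or 1x1 projection shortcut before the final ReLU.

```python
import torch
import torch.nn as nn


class Bottleneck(nn.Module):
    """Simplified ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convs plus a shortcut."""

    expansion = 4  # output channels = planes * expansion

    def __init__(self, in_channels, planes, stride=1):
        super().__init__()
        out_channels = planes * self.expansion
        self.conv1 = nn.Conv2d(in_channels, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Downsampling (stride > 1) happens in the 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Projection shortcut when the shape changes, identity otherwise
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Residual connection: add the (possibly projected) input, then ReLU
        return self.relu(out + self.shortcut(x))


if __name__ == '__main__':
    block = Bottleneck(in_channels=64, planes=64, stride=2)
    x = torch.randn(2, 64, 56, 56)
    print(block(x).shape)  # torch.Size([2, 256, 28, 28])
```

Stacking such blocks in four stages of 3, 4, 23 and 3 blocks gives ResNet-101, the network the script creates with models.resnet101.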

import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models

# ---- 1. Split the original training set into train/test folders (90% / 10%) ----
random_state = 42
np.random.seed(random_state)

original_dataset_dir = '/home/wangyunhao/dc/dogs-vs-cats/train/train'
# The folder holds cat.<i>.jpg and dog.<i>.jpg in equal numbers,
# so half the file count is the number of images per class.
total_num = int(len(os.listdir(original_dataset_dir)) / 2)
random_idx = np.array(range(total_num))
np.random.shuffle(random_idx)

base_dir = '/home/wangyunhao/dc/dog_cat_deal'
os.makedirs(base_dir, exist_ok=True)
sub_dirs = ['train', 'test']
animals = ['cats', 'dogs']
train_idx = random_idx[:int(total_num * 0.9)]
test_idx = random_idx[int(total_num * 0.9):]
numbers = [train_idx, test_idx]

for idx, sub_dir in enumerate(sub_dirs):
    split_dir = os.path.join(base_dir, sub_dir)  # renamed from `dir`, which shadows a builtin
    os.makedirs(split_dir, exist_ok=True)
    for animal in animals:
        animal_dir = os.path.join(split_dir, animal)
        os.makedirs(animal_dir, exist_ok=True)
        # 'cats' -> cat.<i>.jpg, 'dogs' -> dog.<i>.jpg
        fnames = [animal[:-1] + '.{}.jpg'.format(i) for i in numbers[idx]]
        for fname in fnames:
            src = os.path.join(original_dataset_dir, fname)
            dst = os.path.join(animal_dir, fname)
            shutil.copyfile(src, dst)
        print(animal_dir + '  total images : %d ' % (len(os.listdir(animal_dir))))

# ---- 2. Training configuration ----
random_state = 1
torch.manual_seed(random_state)
torch.cuda.manual_seed(random_state)
torch.cuda.manual_seed_all(random_state)
np.random.seed(random_state)

epochs = 10
batch_size = 10
num_workers = 0  # 0 = every batch is read and decoded in the main process (simple but slow)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = '/home/wangyunhao/dc/dc_dog.pt'

# Resize, center-crop to 224x224, then normalize with the ImageNet mean/std
data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ImageFolder labels classes by sorted folder name: cats -> 0, dogs -> 1
train_dataset = datasets.ImageFolder(root='/home/wangyunhao/dc/dog_cat_deal/train/', transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataset = datasets.ImageFolder(root='/home/wangyunhao/dc/dog_cat_deal/test', transform=data_transform)
# Evaluation does not need shuffling
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# ---- 3. Model, loss, optimizer ----
# ResNet-101 with random initialization and 2 output classes
net = models.resnet101(num_classes=2)
if os.path.exists(model_path):
    # Resume from the state_dict saved by a previous run
    # (a file written with torch.save(net, ...) holds a whole model and will not load this way)
    net.load_state_dict(torch.load(model_path, map_location=device))
net = net.to(device)
print(net)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)


def train():
    for epoch in range(epochs):
        # ---- training pass ----
        net.train()  # switch back from eval mode set at the end of the previous epoch
        running_loss = 0.0
        train_correct = 0
        train_total = 0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, train_predicted = torch.max(outputs, 1)
            train_correct += (train_predicted == labels).sum().item()
            # loss is the batch mean; weight by batch size to get a per-image average
            running_loss += loss.item() * labels.size(0)
            train_total += labels.size(0)
            if i % 100 == 0:
                print('epoch %d batch %d loss: %.3f' % (epoch + 1, i, loss.item()))
        print('train %d epoch loss: %.3f  acc: %.3f ' % (epoch + 1, running_loss / train_total, 100 * train_correct / train_total))

        # ---- evaluation pass ----
        net.eval()
        correct = 0
        test_loss = 0.0
        test_total = 0
        with torch.no_grad():  # no gradients needed for evaluation
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                test_loss += loss.item() * labels.size(0)
                test_total += labels.size(0)
                correct += (predicted == labels).sum().item()
        print('test  %d epoch loss: %.3f  acc: %.3f' % (epoch + 1, test_loss / test_total, 100 * correct / test_total))

        # Save a checkpoint (weights only) after every epoch
        torch.save(net.state_dict(), model_path)


train()
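The script calls net.eval() before measuring test accuracy. That call matters because ResNet-101 contains BatchNorm layers: in training mode they normalize with the current batch's statistics (and update running averages), while in eval mode they use the stored running averages. The mode is a flag on the module that persists until net.train() is called, so a loop that evaluates after every epoch needs to switch back before the next training pass. A minimal sketch with a single standalone BatchNorm1d layer (not part of the original script) shows the difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)           # fresh layer: running_mean = 0, running_var = 1
x = torch.randn(8, 3) * 5 + 10   # a batch whose statistics are far from 0 / 1

bn.train()
y_train = bn(x)  # normalized with this batch's mean/var; running stats move toward them

bn.eval()
y_eval = bn(x)   # normalized with the running estimates instead, so the output differs

print(y_train.mean(dim=0))  # close to 0 for every channel
print(y_eval.mean(dim=0))   # still far from 0: the running stats have seen only one batch
```

The same applies to the whole network: training while accidentally left in eval mode normalizes with running statistics instead of batch statistics, which changes what the model learns.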

Published: 2024-02-28 07:56:20
Article link: https://www.elefans.com/category/jswz/34/1768742.html