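[caffe] OpenCV: Load a Caffe model (repost)
Original article:
In the previous post we covered building the opencv_contrib modules on Windows, and mentioned that its dnn module can load models trained with Caffe for object detection. This post shows in detail how to use dnn to load a Caffe model and use it to classify an image.
The code is below (mostly taken from references [2] and [3]):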
#include <opencv2/dnn.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#include <cstdlib>

/* Find best class for the blob (i.e. class with maximal probability) */
void getMaxClass(cv::dnn::Blob &probBlob, int *classId, double *classProb)
{
    cv::Mat probMat = probBlob.matRefConst().reshape(1, 1); // reshape the blob to a 1x1000 matrix
    cv::Point classNumber;
    cv::minMaxLoc(probMat, NULL, classProb, NULL, &classNumber);
    *classId = classNumber.x; // column index = class index
}

/* Read one label per line, dropping the leading WordNet ID */
std::vector<cv::String> readClassNames(const char *filename = "synset_words.txt")
{
    std::vector<cv::String> classNames;
    std::ifstream fp(filename);
    if (!fp.is_open())
    {
        std::cerr << "File with classes labels not found: " << filename << std::endl;
        exit(-1);
    }
    std::string name;
    while (std::getline(fp, name)) // loop ends as soon as a read fails
    {
        if (name.length())
            classNames.push_back(name.substr(name.find(' ') + 1));
    }
    fp.close();
    return classNames;
}

int main(int argc, char **argv)
{
    cv::dnn::initModule(); // required if OpenCV is built as static libs

    cv::String modelTxt = "bvlc_googlenet.prototxt";
    cv::String modelBin = "bvlc_googlenet.caffemodel";
    cv::String imageFile = "space_shuttle.jpg";

    cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);
    if (net.empty())
    {
        std::cerr << "Can't load network by using the following files: " << std::endl;
        std::cerr << "prototxt: " << modelTxt << std::endl;
        std::cerr << "caffemodel: " << modelBin << std::endl;
        std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
        std::cerr << ".caffemodel" << std::endl;
        exit(-1);
    }

    //! [Prepare blob]
    cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
    if (img.empty())
    {
        std::cerr << "Can't read image from the file: " << imageFile << std::endl;
        exit(-1);
    }
    cv::resize(img, img, cv::Size(224, 224)); // GoogLeNet takes 224x224 input
    // Convert Mat to a dnn::Blob image batch; cv::dnn::Blob(img) from older versions fails here
    cv::dnn::Blob inputBlob = cv::dnn::Blob::fromImages(img);
    //! [Prepare blob]

    //! [Set input blob]
    net.setBlob(".data", inputBlob); // set the network input
    //! [Set input blob]

    //! [Make forward pass]
    net.forward(); // compute output
    //! [Make forward pass]

    //! [Gather output]
    cv::dnn::Blob prob = net.getBlob("prob"); // gather output of "prob" layer
    int classId;
    double classProb;
    getMaxClass(prob, &classId, &classProb); // find the best class
    //! [Gather output]

    //! [Print results]
    std::vector<cv::String> classNames = readClassNames();
    std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
    std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
    //! [Print results]

    return 0;
} // main
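Code walkthrough:
1. First, download the GoogLeNet model and the classification files from the official site (or copy and paste them): bvlc_googlenet.prototxt, bvlc_googlenet.caffemodel and synset_words.txt. You can also download the bundle I uploaded, which includes these files and the image from step 2.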
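Each line of synset_words.txt pairs a WordNet ID with a human-readable label, and readClassNames() keeps only the part after the first space; the line's position in the file becomes the class index used by classNames.at(classId). The sketch below isolates that parsing step in plain C++. The helper name labelFromSynsetLine and the sample line in the usage note are illustrative assumptions, not part of the original code.

```cpp
#include <string>

// Hypothetical helper: strip the leading WordNet ID from one line of
// synset_words.txt, mirroring name.substr(name.find(' ') + 1) in readClassNames().
// Assumes the "<wnid> <label>" layout; a line without a space is returned whole,
// which is also what the original expression does (npos + 1 wraps to 0).
std::string labelFromSynsetLine(const std::string& line) {
    const std::string::size_type space = line.find(' ');
    if (space == std::string::npos) {
        return line; // no ID prefix: keep the whole line
    }
    return line.substr(space + 1);
}
```

For a line such as `n01440764 tench, Tinca tinca`, this returns `tench, Tinca tinca`; commas inside the label are kept because only the first space is used as the separator.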
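2. Download the image to classify (the code expects it as space_shuttle.jpg), shown below:
[Image: Buran space shuttle]
3. Read the .prototxt and .caffemodel files: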
cv::dnn::Net net = cv::dnn::readNetFromCaffe(modelTxt, modelBin);
4. Check that the network was loaded successfully:
if (net.empty())
{
    std::cerr << "Can't load network by using the following files: " << std::endl;
    std::cerr << "prototxt: " << modelTxt << std::endl;
    std::cerr << "caffemodel: " << modelBin << std::endl;
    std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
    std::cerr << ".caffemodel" << std::endl;
    exit(-1);
}
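The check in step 4 only says that loading failed, not which file caused it. One option, sketched below under the assumption that a missing or misnamed file is the usual culprit, is to test both paths before calling readNetFromCaffe. The helpers fileIsReadable and modelFilesReadable are hypothetical names; they only check that each file can be opened, not that its contents are a valid Caffe model.

```cpp
#include <fstream>
#include <iostream>
#include <string>

// Hypothetical helper: true if the file exists and can be opened for reading.
bool fileIsReadable(const std::string& path) {
    std::ifstream in(path, std::ios::binary);
    return in.good();
}

// Hypothetical helper: report each model file that cannot be opened.
// Checks both files so that one run lists every missing file.
bool modelFilesReadable(const std::string& prototxt, const std::string& caffemodel) {
    bool ok = true;
    if (!fileIsReadable(prototxt)) {
        std::cerr << "Cannot open prototxt: " << prototxt << std::endl;
        ok = false;
    }
    if (!fileIsReadable(caffemodel)) {
        std::cerr << "Cannot open caffemodel: " << caffemodel << std::endl;
        ok = false;
    }
    return ok;
}
```

Calling `modelFilesReadable(modelTxt, modelBin)` at the top of main() and exiting when it returns false would make the later net.empty() branch mostly about malformed files rather than missing ones.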
5. Read the image and convert it into a blob that GoogLeNet can take as input:
cv::Mat img = cv::imread(imageFile, cv::IMREAD_COLOR);
if (img.empty())
{
    std::cerr << "Can't read image from the file: " << imageFile << std::endl;
    exit(-1);
}
cv::resize(img, img, cv::Size(224, 224)); // GoogLeNet takes 224x224 input
// Convert Mat to a dnn::Blob image batch; cv::dnn::Blob(img) from older versions fails here
cv::dnn::Blob inputBlob = cv::dnn::Blob::fromImages(img);
6. Pass the blob to the network:
net.setBlob(".data", inputBlob); //set the network input
7. Run the forward pass:
net.forward(); //compute output
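net.forward() runs the network up to its last layer. In the bvlc_googlenet deploy prototxt that layer, named "prob", is a Softmax layer: it turns the 1000 raw class scores into probabilities that sum to 1, which is why step 9 can print the result as a percentage. The sketch below shows the softmax computation itself in plain C++; it illustrates the math and is not OpenCV's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax: subtracting the largest score before
// exponentiating keeps std::exp from overflowing on large scores
// without changing the result.
std::vector<double> softmax(const std::vector<double>& scores) {
    std::vector<double> probs(scores.size());
    if (scores.empty()) return probs;
    const double maxScore = *std::max_element(scores.begin(), scores.end());
    double sum = 0.0;
    for (std::size_t i = 0; i < scores.size(); ++i) {
        probs[i] = std::exp(scores[i] - maxScore);
        sum += probs[i];
    }
    for (double& p : probs) p /= sum; // normalize so the outputs sum to 1
    return probs;
}
```

Because softmax is monotonic, the class with the highest raw score is also the class with the highest probability.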
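8. Classify (fetch the output of the "prob" layer and take the class with the highest probability):
cv::dnn::Blob prob = net.getBlob("prob"); //gather output of "prob" layer
int classId;
double classProb;
getMaxClass(prob, &classId, &classProb); //find the best class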
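getMaxClass() reshapes the output to a single row and uses cv::minMaxLoc to find the column with the largest value; that column index is the class index. The same argmax in plain C++ looks roughly like this, assuming the probabilities have already been copied into a std::vector<float> (the function name argMax is hypothetical).

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Plain-C++ analogue of getMaxClass(): index and value of the largest
// probability. Assumes probs is non-empty; std::max_element returns the
// first maximum if several entries tie.
void argMax(const std::vector<float>& probs, int* classId, double* classProb) {
    auto best = std::max_element(probs.begin(), probs.end());
    *classId = static_cast<int>(std::distance(probs.begin(), best));
    *classProb = *best;
}
```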
9. Print the classification result:
std::vector<cv::String> classNames = readClassNames();
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
std::cout << "Probability: " << classProb * 100 << "%" << std::endl;
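The program above prints only the single best class. For ImageNet-trained models it is common to look at the top few candidates as well; a minimal sketch, assuming the 1000 probabilities have been copied into a std::vector<float>, is a partial sort of class indices by probability. The function name topK is hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Return the indices of the k highest probabilities, best first.
// If k exceeds the number of classes, all indices are returned.
std::vector<int> topK(const std::vector<float>& probs, std::size_t k) {
    std::vector<int> idx(probs.size());
    std::iota(idx.begin(), idx.end(), 0); // 0, 1, 2, ...
    k = std::min(k, idx.size());
    // Only the first k positions are fully sorted, which is cheaper than
    // sorting all 1000 indices.
    std::partial_sort(idx.begin(), idx.begin() + static_cast<std::ptrdiff_t>(k), idx.end(),
                      [&probs](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(k);
    return idx;
}
```

Each returned index can be looked up with classNames.at(i) exactly as in step 9.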
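Running the program with the blob built by cv::dnn::Blob(img) produced an error. After a long search I finally found the solution in reference [3]: that way of converting image data into a blob comes from an older version of the API and is not compatible with the newer version. The fix is to replace cv::dnn::Blob(img) with cv::dnn::Blob::fromImages(img).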
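As its name suggests, Blob::fromImages builds an image batch. Conceptually, that means repacking a cv::Mat, which stores a color image interleaved (height × width × channels, pixel by pixel), into the planar N × C × H × W layout that Caffe networks take as input. The sketch below shows only that layout change in plain C++; it is an illustration under that assumption, not OpenCV's implementation, and the function name hwcToNchw is hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Repack an interleaved H x W x C 8-bit image into a planar
// 1 x C x H x W float buffer (batch size 1).
std::vector<float> hwcToNchw(const std::vector<unsigned char>& hwc,
                             int height, int width, int channels) {
    std::vector<float> nchw(static_cast<std::size_t>(height) * width * channels);
    for (int c = 0; c < channels; ++c)
        for (int y = 0; y < height; ++y)
            for (int x = 0; x < width; ++x) {
                // Interleaved source: channels vary fastest.
                const std::size_t src = (static_cast<std::size_t>(y) * width + x) * channels + c;
                // Planar destination: one full H x W plane per channel.
                const std::size_t dst = (static_cast<std::size_t>(c) * height + y) * width + x;
                nchw[dst] = static_cast<float>(hwc[src]);
            }
    return nchw;
}
```

For a 1 × 2 image with pixels (1, 2, 3) and (4, 5, 6), the result is the three planes {1, 4}, {2, 5}, {3, 6} stored one after another.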
After this change, the program ran successfully and printed the classification result.
References:
[1] .html
[2]
[3]
-----------------------------------------
2017.07.24