OpenPose Training Code (Part 2)



OpenPose training code (Part 1):
OpenPose training code (Part 2):


In the previous post, OpenPose training code (Part 1), we mentioned cpm_data_transformer. This file is actually where the core data-processing code lives, and the previous post introduced the Transform_nv function, so let's start by looking at it:

template<typename Dtype>
void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum,
                                             Blob<Dtype>* transformed_data,
                                             Blob<Dtype>* transformed_label,
                                             int cnt) {
  const int datum_channels = datum.channels();
  const int im_channels = transformed_data->channels();
  const int im_num = transformed_data->num();
  const int lb_num = transformed_label->num();

  CHECK_EQ(datum_channels, 6);
  CHECK_EQ(im_channels, 6);
  CHECK_EQ(im_num, lb_num);
  CHECK_GE(im_num, 1);

  Dtype* transformed_data_pointer = transformed_data->mutable_cpu_data();
  Dtype* transformed_label_pointer = transformed_label->mutable_cpu_data();

  CPUTimer timer;
  timer.Start();
  Transform_nv(datum, transformed_data_pointer, transformed_label_pointer, cnt); // call function 1
  VLOG(2) << "Transform_nv: " << timer.MicroSeconds() / 1000.0 << " ms";
}

This wrapper just pulls shape information from the lmdb datum and the output blobs (datum_channels, im_channels, and so on), runs a few sanity checks (both the datum and the data blob must have 6 channels), and then delegates to the pointer-based overload of Transform_nv:

template<typename Dtype>
void CPMDataTransformer<Dtype>::Transform_nv(const Datum& datum, Dtype* transformed_data,
                                             Dtype* transformed_label, int cnt) {
  ...
}

Inside the overload, data is the raw byte buffer of the lmdb record, and datum_channels, datum_height, datum_width are the per-plane dimensions decided by the earlier Python code. mask_miss is initialized to an all-ones matrix and mask_all to all zeros, ready to be filled in for later use.

  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  // To do: make this a parameter in caffe.proto
  const int mode = 5; // related to datum.channels()
  const bool has_uint8 = data.size() > 0;
  int crop_x = param_.crop_size_x();
  int crop_y = param_.crop_size_y();
  CHECK_GT(datum_channels, 0);

  CPUTimer timer1;
  timer1.Start();
  // before any transformation, get the image from datum
  Mat img = Mat::zeros(datum_height, datum_width, CV_8UC3);
  Mat mask_all, mask_miss;
  if (mode >= 5) {
    mask_miss = Mat::ones(datum_height, datum_width, CV_8UC1);
  }
  if (mode == 6) {
    mask_all = Mat::zeros(datum_height, datum_width, CV_8UC1);
  }

The raw image pixels are then read into rgb, followed by mask_miss and mask_all, as shown below.
offset = img.rows * img.cols is the size of one plane, so c*offset steps from plane to plane, matching the layout written by the Python file one-to-one (the full plane layout is summarized after the block below).

  int offset = img.rows * img.cols;
  int dindex;
  Dtype d_element;
  for (int i = 0; i < img.rows; ++i) {
    for (int j = 0; j < img.cols; ++j) {
      Vec3b& rgb = img.at<Vec3b>(i, j);
      for (int c = 0; c < 3; c++) {
        dindex = c*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        rgb[c] = d_element;
      }
      if (mode >= 5) {
        dindex = 4*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        if (round(d_element/255) != 1 && round(d_element/255) != 0) {
          cout << d_element << " " << round(d_element/255) << endl;
        }
        mask_miss.at<uchar>(i, j) = d_element; // round(d_element/255);
      }
      if (mode == 6) {
        dindex = 5*offset + i*img.cols + j;
        if (has_uint8)
          d_element = static_cast<Dtype>(static_cast<uint8_t>(data[dindex]));
        else
          d_element = datum.float_data(dindex);
        mask_all.at<uchar>(i, j) = d_element;
      }
    }
  }
  VLOG(2) << "  rgb[:] = datum: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();
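For orientation, this is the plane layout of one 6-channel datum as implied by the offsets above and by the offset3 = 3*offset passed to ReadMetaData below. The planeIndex helper is mine, added only to make the indexing explicit:

  // One datum = 6 planes of datum_height x datum_width bytes, back to back:
  //   planes 0-2 : the image channels (OpenCV BGR order, despite the name "rgb")
  //   plane  3   : serialized meta data (keypoints, scale, objpos, ...)
  //   plane  4   : mask_miss (mode >= 5)
  //   plane  5   : mask_all  (mode == 6)
  // Hypothetical helper mirroring "dindex = c*offset + i*img.cols + j" above.
  inline int planeIndex(int plane, int row, int col, int rows, int cols) {
    return plane * rows * cols + row * cols + col;
  }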

Next comes the meta plane, which stores the keypoints and scale information. The key routine is ReadMetaData, which parses the buffer in exactly the format the Python code wrote it in, so it pays to be clear about the Python side first; otherwise this part gets confusing quickly. There is also a small trick here: TransformMetaJoints reorders the keypoints to line the COCO annotations up with the MPII-style joint order, which, as I understand it, makes it easier to transfer pretrained weights. The code is as follows (a sketch of the reordering idea follows the block):

  // color, contrast
  if (param_.do_clahe())
    clahe(img, clahe_tileSize, clahe_clipLimit);
  if (param_.gray() == 1) {
    cv::cvtColor(img, img, CV_BGR2GRAY);
    cv::cvtColor(img, img, CV_GRAY2BGR);
  }
  VLOG(2) << "  color: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();

  int offset3 = 3 * offset;
  int offset1 = datum_width;
  int stride = param_.stride();
  ReadMetaData(meta, data, offset3, offset1);
  if (param_.transform_body_joint()) // we expect to transform body joints, and not to transform hand joints
    TransformMetaJoints(meta);
  VLOG(2) << "  ReadMeta+MetaJoints: " << timer1.MicroSeconds()/1000.0 << " ms";
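The body of TransformMetaJoints is not reproduced in this post, but the idea is straightforward: build a new Joints struct whose i-th entry is taken from (or averaged between) two source indices, so that, for example, a neck point can be synthesized as the midpoint of the two shoulders. A minimal sketch of that pattern, using the Joints struct from the same file; the ours_from_a/ours_from_b arrays are illustrative stand-ins, the real mapping arrays live in cpm_data_transformer.cpp:

  // Sketch of the joint-reordering idea behind TransformMetaJoints.
  // When the two source indices differ, the target joint is the midpoint of
  // the two source joints (e.g. neck from the two shoulders).
  void TransformJointsSketch(Joints& j, const int* ours_from_a,
                             const int* ours_from_b, int np_target) {
    Joints jo = j;
    jo.joints.resize(np_target);
    jo.isVisible.resize(np_target);
    for (int i = 0; i < np_target; i++) {
      jo.joints[i] = (j.joints[ours_from_a[i]] + j.joints[ours_from_b[i]]) * 0.5f;
      if (j.isVisible[ours_from_a[i]] >= 2 || j.isVisible[ours_from_b[i]] >= 2)
        jo.isVisible[i] = 2; // at least one source unlabeled -> unlabeled
      else
        jo.isVisible[i] = std::max(j.isVisible[ours_from_a[i]],
                                   j.isVisible[ours_from_b[i]]);
    }
    j = jo;
  }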

With the raw data in hand, the next step is data augmentation. The original code applies four kinds: scale, rotate, crop, and flip. Each one is applied on top of the previous result, and the scale step makes use of the annotated scale of the main person in the original image (see the sketch after the block below):

  // Start transforming
  Mat img_aug = Mat::zeros(crop_y, crop_x, CV_8UC3);
  Mat mask_miss_aug, mask_all_aug;
  Mat img_temp, img_temp2, img_temp3; // size determined by scale
  VLOG(2) << "   input size (" << img.cols << ", " << img.rows << ")";
  // We only do random transform as augmentation when training.
  if (phase_ == TRAIN) {
    as.scale = augmentation_scale(img, img_temp, mask_miss, mask_all, meta, mode);
    as.degree = augmentation_rotate(img_temp, img_temp2, mask_miss, mask_all, meta, mode);
    if (0 && param_.visualize()) visualize(img_temp2, meta, as);
    as.crop = augmentation_croppad(img_temp2, img_temp3, mask_miss, mask_miss_aug,
                                   mask_all, mask_all_aug, meta, mode);
    if (0 && param_.visualize()) visualize(img_temp3, meta, as);
    as.flip = augmentation_flip(img_temp3, img_aug, mask_miss_aug, mask_all_aug, meta, mode);
    if (param_.visualize()) visualize(img_aug, meta, as);
    if (mode > 4) {
      resize(mask_miss_aug, mask_miss_aug, Size(), 1.0/stride, 1.0/stride, INTER_CUBIC);
    }
    if (mode > 5) {
      resize(mask_all_aug, mask_all_aug, Size(), 1.0/stride, 1.0/stride, INTER_CUBIC);
    }
  }
  else {
    img_aug = img.clone();
    as.scale = 1;
    as.crop = Size();
    as.flip = 0;
    as.degree = 0;
  }
  VLOG(2) << "  Aug: " << timer1.MicroSeconds()/1000.0 << " ms";
  timer1.Start();
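None of the four augmentation_* helpers are listed in this post. As an illustration of how the annotated person scale enters, here is a simplified sketch of what augmentation_scale does, assuming the scale_prob/scale_min/scale_max/target_dist fields of the transform parameter; the real version also rescales mask_miss, mask_all, and every other person's joints:

  // Simplified sketch of augmentation_scale. meta.scale_self records how
  // large the main person already is, so target_dist/scale_self first brings
  // them to a canonical size before the random jitter is applied on top.
  float augmentation_scale_sketch(Mat& img_src, Mat& img_temp, MetaData& meta,
                                  const CPMTransformationParameter& param) {
    float dice = static_cast<float>(rand()) / RAND_MAX;
    float scale_multiplier = 1.0f;            // jitter with probability scale_prob
    if (dice <= param.scale_prob()) {
      float dice2 = static_cast<float>(rand()) / RAND_MAX;
      scale_multiplier = (param.scale_max() - param.scale_min()) * dice2 + param.scale_min();
    }
    float scale = (param.target_dist() / meta.scale_self) * scale_multiplier;
    resize(img_src, img_temp, Size(), scale, scale, INTER_CUBIC);
    // keep the annotations consistent with the resized image
    meta.objpos *= scale;
    for (size_t i = 0; i < meta.joint_self.joints.size(); i++)
      meta.joint_self.joints[i] *= scale;
    return scale_multiplier;
  }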

After augmentation comes normalization and label preparation. One detail worth noting is that the label channel responsible for the background uses the mask_miss information. The input pixels are normalized to roughly [-0.5, 0.5] via (v - 128)/256 (strictly, [-0.5, 0.496]):

  for (int i = 0; i < img_aug.rows; ++i) {
    for (int j = 0; j < img_aug.cols; ++j) {
      Vec3b& rgb = img_aug.at<Vec3b>(i, j);
      transformed_data[0*offset + i*img_aug.cols + j] = (rgb[0] - 128)/256.0;
      transformed_data[1*offset + i*img_aug.cols + j] = (rgb[1] - 128)/256.0;
      transformed_data[2*offset + i*img_aug.cols + j] = (rgb[2] - 128)/256.0;
    }
  }

  // label size is image size / stride
  if (mode > 4) {
    for (int g_y = 0; g_y < grid_y; g_y++) {
      for (int g_x = 0; g_x < grid_x; g_x++) {
        for (int i = 0; i < np; i++) {
          float weight = float(mask_miss_aug.at<uchar>(g_y, g_x)) / 255;
          if (meta.joint_self.isVisible[i] != 3) {
            transformed_label[i*channelOffset + g_y*grid_x + g_x] = weight;
          }
        }
        // background channel
        if (mode == 5) {
          transformed_label[np*channelOffset + g_y*grid_x + g_x] =
              float(mask_miss_aug.at<uchar>(g_y, g_x)) / 255;
        }
        if (mode > 5) {
          transformed_label[np*channelOffset + g_y*grid_x + g_x] = 1;
          transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] =
              float(mask_all_aug.at<uchar>(g_y, g_x)) / 255;
        }
      }
    }
  }
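Keeping the overall channel layout of transformed_label in mind makes the index arithmetic above and in generateLabelMap much easier to follow. For the COCO setting (np == 56), working through the offsets gives the following summary (mine, derived from the code rather than stated in it):

  // transformed_label layout for np == 56 (COCO); each channel is grid_y x grid_x:
  //   channels   0-55  : per-part loss weights copied from mask_miss
  //   channel      56  : weight for the background channel   (np)
  //   channels  57-94  : 19 PAFs, one x/y pair per limb      (np+1+2i, np+2+2i)
  //   channels  95-112 : 18 keypoint heat maps               (i+np+39)
  //   channel     113  : background heat map                 (2*np+1)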

With the image data normalized and the weight/background channels in place, what remains are the keypoint heat maps and the PAF labels, which are produced in the generateLabelMap function.

  // putGaussianMaps(transformed_data + 3*offset, meta.objpos, 1, img_aug.cols, img_aug.rows, param_.sigma_center());
  generateLabelMap(transformed_label, img_aug, meta);
  VLOG(2) << "  putGauss+genLabel: " << timer1.MicroSeconds()/1000.0 << " ms";

  // starts to visualize everything (transformed_data in 4 ch, label) fed into conv1
  // if (param_.visualize()) {
  //   dumpEverything(transformed_data, transformed_label, meta);
  // }

Concretely, let's look at generateLabelMap. Broadly speaking, it does two things: it places a Gaussian response at every keypoint location, and it lays down a vector field between every pair of connected keypoints (the PAF). For finer details, consult the source; sketches of the two helpers follow the listing below:

template<typename Dtype>
void CPMDataTransformer<Dtype>::generateLabelMap(Dtype* transformed_label, Mat& img_aug, MetaData meta) {
  int rezX = img_aug.cols;
  int rezY = img_aug.rows;
  int stride = param_.stride();
  int grid_x = rezX / stride;
  int grid_y = rezY / stride;
  int channelOffset = grid_y * grid_x;
  int mode = 5; // TO DO: make this as a parameter

  // clear the label channels that will hold the actual heat maps / PAFs
  for (int g_y = 0; g_y < grid_y; g_y++) {
    for (int g_x = 0; g_x < grid_x; g_x++) {
      for (int i = np+1; i < 2*(np+1); i++) {
        if (mode == 6 && i == (2*np + 1))
          continue;
        transformed_label[i*channelOffset + g_y*grid_x + g_x] = 0;
      }
    }
  }

  if (np == 56) { // COCO: 18 keypoints, 19 limbs
    for (int i = 0; i < 18; i++) {
      Point2f center = meta.joint_self.joints[i];
      if (meta.joint_self.isVisible[i] <= 1) {
        putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center,
                        param_.stride(), grid_x, grid_y, param_.sigma()); // self
      }
      for (int j = 0; j < meta.numOtherPeople; j++) { // for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if (meta.joint_others[j].isVisible[i] <= 1) {
          putGaussianMaps(transformed_label + (i+np+39)*channelOffset, center,
                          param_.stride(), grid_x, grid_y, param_.sigma());
        }
      }
    }

    int mid_1[19] = {2, 9,  10, 2,  12, 13, 2, 3, 4, 3,  2, 6, 7, 6,  2, 1,  1,  15, 16};
    int mid_2[19] = {9, 10, 11, 12, 13, 14, 3, 4, 5, 17, 6, 7, 8, 18, 1, 15, 16, 17, 18};
    int thre = 1;
    for (int i = 0; i < 19; i++) {
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if (jo.isVisible[mid_1[i]-1] <= 1 && jo.isVisible[mid_2[i]-1] <= 1) {
        putVecMaps(transformed_label + (np+1+2*i)*channelOffset,
                   transformed_label + (np+2+2*i)*channelOffset, count,
                   jo.joints[mid_1[i]-1], jo.joints[mid_2[i]-1],
                   param_.stride(), grid_x, grid_y, param_.sigma(), thre); // self
      }
      for (int j = 0; j < meta.numOtherPeople; j++) { // for every other person
        Joints jo2 = meta.joint_others[j];
        if (jo2.isVisible[mid_1[i]-1] <= 1 && jo2.isVisible[mid_2[i]-1] <= 1) {
          putVecMaps(transformed_label + (np+1+2*i)*channelOffset,
                     transformed_label + (np+2+2*i)*channelOffset, count,
                     jo2.joints[mid_1[i]-1], jo2.joints[mid_2[i]-1],
                     param_.stride(), grid_x, grid_y, param_.sigma(), thre);
        }
      }
    }

    // put background channel
    for (int g_y = 0; g_y < grid_y; g_y++) {
      for (int g_x = 0; g_x < grid_x; g_x++) {
        float maximum = 0;
        for (int i = np+39; i < np+57; i++) {
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ?
                     maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
  }
  else if (np == 43) { // MPII: 15 keypoints, 14 limbs
    for (int i = 0; i < 15; i++) {
      Point2f center = meta.joint_self.joints[i];
      if (meta.joint_self.isVisible[i] <= 1) {
        putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center,
                        param_.stride(), grid_x, grid_y, param_.sigma()); // self
      }
      for (int j = 0; j < meta.numOtherPeople; j++) { // for every other person
        Point2f center = meta.joint_others[j].joints[i];
        if (meta.joint_others[j].isVisible[i] <= 1) {
          putGaussianMaps(transformed_label + (i+np+29)*channelOffset, center,
                          param_.stride(), grid_x, grid_y, param_.sigma());
        }
      }
    }

    int mid_1[14] = {0, 1, 2, 3, 1, 5, 6, 1,  14, 8, 9,  14, 11, 12};
    int mid_2[14] = {1, 2, 3, 4, 5, 6, 7, 14, 8,  9, 10, 11, 12, 13};
    int thre = 1;
    for (int i = 0; i < 14; i++) {
      Mat count = Mat::zeros(grid_y, grid_x, CV_8UC1);
      Joints jo = meta.joint_self;
      if (jo.isVisible[mid_1[i]] <= 1 && jo.isVisible[mid_2[i]] <= 1) {
        putVecMaps(transformed_label + (np+1+2*i)*channelOffset,
                   transformed_label + (np+2+2*i)*channelOffset, count,
                   jo.joints[mid_1[i]], jo.joints[mid_2[i]],
                   param_.stride(), grid_x, grid_y, param_.sigma(), thre); // self
      }
      for (int j = 0; j < meta.numOtherPeople; j++) { // for every other person
        Joints jo2 = meta.joint_others[j];
        if (jo2.isVisible[mid_1[i]] <= 1 && jo2.isVisible[mid_2[i]] <= 1) {
          putVecMaps(transformed_label + (np+1+2*i)*channelOffset,
                     transformed_label + (np+2+2*i)*channelOffset, count,
                     jo2.joints[mid_1[i]], jo2.joints[mid_2[i]],
                     param_.stride(), grid_x, grid_y, param_.sigma(), thre);
        }
      }
    }

    // put background channel
    for (int g_y = 0; g_y < grid_y; g_y++) {
      for (int g_x = 0; g_x < grid_x; g_x++) {
        float maximum = 0;
        for (int i = np+29; i < np+44; i++) {
          maximum = (maximum > transformed_label[i*channelOffset + g_y*grid_x + g_x]) ?
                     maximum : transformed_label[i*channelOffset + g_y*grid_x + g_x];
        }
        transformed_label[(2*np+1)*channelOffset + g_y*grid_x + g_x] = max(1.0-maximum, 0.0);
      }
    }
  }

  // visualize
  if (1 && param_.visualize()) {
    Mat label_map;
    for (int i = 0; i < 2*(np+1); i++) {
      label_map = Mat::zeros(grid_y, grid_x, CV_8UC1);
      for (int g_y = 0; g_y < grid_y; g_y++) {
        for (int g_x = 0; g_x < grid_x; g_x++) {
          label_map.at<uchar>(g_y, g_x) =
              (int)(transformed_label[i*channelOffset + g_y*grid_x + g_x]*255);
        }
      }
      resize(label_map, label_map, Size(), stride, stride, INTER_LINEAR);
      applyColorMap(label_map, label_map, COLORMAP_JET);
      addWeighted(label_map, 0.5, img_aug, 0.5, 0.0, label_map);
      char imagename[100];
      sprintf(imagename, "augment_%04d_label_part_%02d.jpg", meta.write_number, i);
      imwrite(imagename, label_map);
    }
  }
}
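To make the two building blocks concrete, here are sketches of putGaussianMaps and putVecMaps. They follow the behaviour described above (a clipped Gaussian bump per keypoint; a band of unit vectors along each limb, averaged where several people overlap), but treat them as sketches rather than verbatim copies of the originals:

  // Sketch of putGaussianMaps: accumulate a Gaussian bump centered on one
  // keypoint into a single grid_y x grid_x label channel, clipped to 1.
  template<typename Dtype>
  void putGaussianMapsSketch(Dtype* entry, Point2f center, int stride,
                             int grid_x, int grid_y, float sigma) {
    float start = stride / 2.0f - 0.5f; // center of the first grid cell, in pixels
    for (int g_y = 0; g_y < grid_y; g_y++) {
      for (int g_x = 0; g_x < grid_x; g_x++) {
        float x = start + g_x * stride;
        float y = start + g_y * stride;
        float d2 = (x - center.x)*(x - center.x) + (y - center.y)*(y - center.y);
        float exponent = d2 / 2.0f / sigma / sigma;
        if (exponent > 4.6052f) continue;   // exp(-4.6052) ~ 0.01, negligible
        entry[g_y*grid_x + g_x] += exp(-exponent);
        if (entry[g_y*grid_x + g_x] > 1) entry[g_y*grid_x + g_x] = 1;
      }
    }
  }

  // Sketch of putVecMaps: write the unit vector centerA->centerB into the two
  // PAF channels for every grid cell within thre cells of the limb segment;
  // count remembers how many people already wrote to a cell, so overlapping
  // limbs are averaged instead of overwritten.
  template<typename Dtype>
  void putVecMapsSketch(Dtype* entryX, Dtype* entryY, Mat& count,
                        Point2f centerA, Point2f centerB, int stride,
                        int grid_x, int grid_y, int thre) {
    centerA *= (1.0f / stride);             // work in grid coordinates
    centerB *= (1.0f / stride);
    Point2f bc = centerB - centerA;
    float norm_bc = sqrt(bc.x*bc.x + bc.y*bc.y);
    if (norm_bc < 1e-6f) return;            // degenerate limb
    bc *= (1.0f / norm_bc);                 // unit direction of the limb

    int min_x = max(int(round(min(centerA.x, centerB.x) - thre)), 0);
    int max_x = min(int(round(max(centerA.x, centerB.x) + thre)), grid_x);
    int min_y = max(int(round(min(centerA.y, centerB.y) - thre)), 0);
    int max_y = min(int(round(max(centerA.y, centerB.y) + thre)), grid_y);

    for (int g_y = min_y; g_y < max_y; g_y++) {
      for (int g_x = min_x; g_x < max_x; g_x++) {
        Point2f ba(g_x - centerA.x, g_y - centerA.y);
        float dist = fabs(ba.x*bc.y - ba.y*bc.x); // distance to the limb line
        if (dist > thre) continue;
        int cnt = count.at<uchar>(g_y, g_x);
        // running average over all people whose limb covers this cell
        entryX[g_y*grid_x + g_x] = (entryX[g_y*grid_x + g_x]*cnt + bc.x) / (cnt + 1);
        entryY[g_y*grid_x + g_x] = (entryY[g_y*grid_x + g_x]*cnt + bc.y) / (cnt + 1);
        count.at<uchar>(g_y, g_x) = cnt + 1;
      }
    }
  }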
