ICDAR2015 Data Processing and Training




Training data preparation:

Two datasets are used: the Tianchi ICPR2018 dataset and MSRA_TD500:

1) The Tianchi ICPR dataset consists of web images, mostly product pictures that merchants upload to Taobao. Its labels follow the ICDAR2015 format: each text box is described by its four corners (top-left, top-right, bottom-right, bottom-left), i.e. eight values recorded as [x1 y1 x2 y2 x3 y3 x4 y4] — for example, a converted label line might look like 377,117,463,117,465,130,378,130.

2) MSRA_TD500 is a text detection and recognition dataset collected by Microsoft. Most of its images are street scenes with fairly complex backgrounds, but the text regions stand out clearly.

MSRA_TD500 uses a different label format, whose last field is the rotation angle of the rectangle.

So the first step is to unify the label formats of the two datasets. My approach is to convert the MSRA labels to the ICDAR format, which makes the later model training easier.

Since MSRA_TD500 labels take the form [index difficulty_label x y w h angle], we need to use the rotation angle to recover the four corner coordinates of the rotated rectangle. The implementation is as follows:

"""
This file is to change MSRA_TD500 dataset format to ICDAR2015 dataset format.MSRA_TD500 format: [index difficulty_label x y w h angle]ICDAR2015 format: [left_top_x left_top_y right_top_X right_top_y right_bottom_x right_bottom_y left_bottom_x left_bottom_y]"""import math
import cv2
import os# 求旋转后矩形的4个坐标
def get_box_img(x, y, w, h, angle):# 矩形框中点(x0,y0)x0 = x + w/2y0 = y + h/2l = math.sqrt(pow(w/2, 2) + pow(h/2, 2))  # 即对角线的一半# angle小于0,逆时针转if angle < 0:a1 = -angle + math.atan(h / float(w))  # 旋转角度-对角线与底线所成的角度a2 = -angle - math.atan(h / float(w)) # 旋转角度+对角线与底线所成的角度pt1 = (x0 - l * math.cos(a2), y0 + l * math.sin(a2))pt2 = (x0 + l * math.cos(a1), y0 - l * math.sin(a1))pt3 = (x0 + l * math.cos(a2), y0 - l * math.sin(a2))  # x0+左下点旋转后在水平线上的投影, y0-左下点在垂直线上的投影,显然逆时针转时,左下点上一和左移了。pt4 = (x0 - l * math.cos(a1), y0 + l * math.sin(a1))else:a1 = angle + math.atan(h / float(w))a2 = angle - math.atan(h / float(w))pt1 = (x0 - l * math.cos(a1), y0 - l * math.sin(a1))pt2 = (x0 + l * math.cos(a2), y0 + l * math.sin(a2))pt3 = (x0 + l * math.cos(a1), y0 + l * math.sin(a1))pt4 = (x0 - l * math.cos(a2), y0 - l * math.sin(a2))return [pt1[0], pt1[1], pt2[0], pt2[1], pt3[0], pt3[1], pt4[0], pt4[1]]def read_file(path):result = []for line in open(path):info = []data = line.split(' ')info.append(int(data[2]))info.append(int(data[3]))info.append(int(data[4]))info.append(int(data[5]))info.append(float(data[6]))info.append(data[0])result.append(info)return resultif __name__ == '__main__':file_path = '/home/ljs/OCR_dataset/MSRA-TD500/test/'save_img_path = '../dataset/OCR_dataset/ctpn/test_im/'save_gt_path = '../dataset/OCR_dataset/ctpn/test_gt/'file_list = os.listdir(file_path)for f in file_list:if '.gt' in f:continuename = f[0:8]txt_path = file_path + name + '.gt'im_path = file_path + fim = cv2.imread(im_path)coordinate = read_file(txt_path)# 仿照ICDAR格式,图片名字写做img_xx.jpg,对应的标签文件写做gt_img_xx.txtcv2.imwrite(save_img_path + name.lower() + '.jpg', im)save_gt = open(save_gt_path + 'gt_' + name.lower() + '.txt', 'w')for i in coordinate:box = get_box_img(i[0], i[1], i[2], i[3], i[4])box = [int(box[i]) for i in range(len(box))]box = [str(box[i]) for i in range(len(box))]save_gt.write(','.join(box))save_gt.write('\n')

After this format conversion, the two datasets are ready. We still need to split the data into a training set and a test set; my file layout is as follows:

The train_im and test_im folders hold the training and test images, while train_gt and test_gt hold the corresponding labels.

Training label generation

Since the core idea of CTPN is also built on the region proposal mechanism of Faster R-CNN, the original box labels have to be converted into anchor labels. Writing the label-generation code is the hardest part: turning one complete text-box label into a series of small fine-scale box labels is genuinely tricky, and the way these anchor labels are generated also differs slightly from Faster R-CNN. My implementation works as follows:

The first step is to convert each image's bbox labels into per-anchor labels. To do this, we first divide the image into anchors of width 16.

  • First compute how many width-16 anchors the image contains (for an image of width w, there are w/16 horizontal anchors), then work out how many anchors the text-box label spans and which anchors are the leftmost and rightmost;
  • Compute the height and centre of each anchor inside the text box: draw the text-box label in white on an all-black mask, then scan from top to bottom and from bottom to top; the first white pixel found in each direction gives the anchor's upper and lower boundary;
  • Finally, store and return each anchor's position (horizontal index), its centre y coordinate, and its height.
def generate_gt_anchor(img, box, anchor_width=16):
    """
    calculate ground truth fine-scale box
    :param img: input image
    :param box: ground truth box (4 point)
    :param anchor_width:
    :return: tuple (position, h, cy)
    """
    if not isinstance(box[0], float):
        box = [float(box[i]) for i in range(len(box))]
    result = []

    # work out how many width-16 anchors one bbox spans, and the ids of the leftmost and rightmost anchors
    left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, downwards
    right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, upwards

    # handle extreme case, the right side anchor may exceed the image width
    if right_anchor_num * 16 + 15 > img.shape[1]:
        right_anchor_num -= 1

    # combine the left-side and the right-side x_coordinate of a text anchor into one pair
    position_pair = [(i * anchor_width, (i + 1) * anchor_width - 1) for i in range(left_anchor_num, right_anchor_num)]

    # compute the real extent of each gt anchor, i.e. its upper and lower boundary
    y_top, y_bottom = cal_y_top_and_bottom(img, position_pair, box)

    # finally store and return each anchor's position (horizontal id), centre y coordinate and height
    for i in range(len(position_pair)):
        position = int(position_pair[i][0] / anchor_width)  # the index of anchor box
        h = y_bottom[i] - y_top[i] + 1  # the height of anchor box
        cy = (float(y_bottom[i]) + float(y_top[i])) / 2.0  # the center point of anchor box
        result.append((position, cy, h))
    return result

The method for computing an anchor's upper and lower boundaries:

# cal the gt anchor box's bottom and top coordinate
def cal_y_top_and_bottom(raw_img, position_pair, box):
    """
    :param raw_img:
    :param position_pair: for example: [(0, 15), (16, 31), ...]
    :param box: gt box (4 point)
    :return: top and bottom coordinates for y-axis
    """
    img = copy.deepcopy(raw_img)
    y_top = []
    y_bottom = []
    height = img.shape[0]

    # build the image mask: channel 0 becomes an all-black image
    for i in range(img.shape[0]):
        for j in range(img.shape[1]):
            img[i, j, 0] = 0

    top_flag = False
    bottom_flag = False

    # draw the text box from its 4 points; the box appears white in channel 0
    img = other.draw_box_4pt(img, box, color=(255, 0, 0))

    for k in range(len(position_pair)):
        # traverse the gt anchors from left to right; for each anchor scan pixels from top to bottom and
        # stop at the first white pixel (255) -- its y coordinate is the anchor's upper boundary
        # calc top y coordinate
        for y in range(0, height - 1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_top.append(y)
                    top_flag = True
                    break
            if top_flag is True:
                break

        # traverse the gt anchors from left to right; for each anchor scan pixels from bottom to top and
        # stop at the first white pixel (255) -- its y coordinate is the anchor's lower boundary
        # calc bottom y coordinate, pixel from down to top loop
        for y in range(height - 1, -1, -1):
            # loop each anchor, from left to right
            for x in range(position_pair[k][0], position_pair[k][1] + 1):
                if img[y, x, 0] == 255:
                    y_bottom.append(y)
                    bottom_flag = True
                    break
            if bottom_flag is True:
                break

        top_flag = False
        bottom_flag = False

    return y_top, y_bottom

With the label processing above, the original standard text-box labels have been converted into individual fine-scale anchor labels; the figure below shows the result of the conversion:

 

The visualised anchor labels look reasonable, but it is worth pointing out that this way of generating anchors is not very precise. For instance, if an edge pixel of a text box just barely falls into a new anchor, we still have to assign a full 16-pixel anchor to that pixel, which clearly makes the text-box label inaccurate and can introduce up to 15 pixels of error. This is something to keep in mind; for now we leave it unhandled and move on.
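To make the error concrete, here is a minimal arithmetic sketch of the 16-pixel quantisation (the box edge values 40 and 161 are made-up numbers for illustration, not from the dataset):

anchor_width = 16
box_left, box_right = 40, 161                     # hypothetical true box edges
left_anchor = box_left // anchor_width            # floor -> anchor 2, which starts at x = 32
right_anchor = -(-box_right // anchor_width)      # ceil  -> anchor index 11
covered_left = left_anchor * anchor_width         # 32
covered_right = right_anchor * anchor_width - 1   # 175
print(covered_left, covered_right)                # anchors cover 32..175 while the true box is 40..161

So the anchor label extends 8 pixels beyond the true left edge and 14 pixels beyond the true right edge, which is exactly the kind of error described above.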

During the conversion we also ran into some odd cases, for example labels whose coordinates extend beyond the image boundary, as in the figure below. These require special handling, such as clamping the label's x coordinates to the image width.

left_anchor_num = int(math.floor(max(min(box[0], box[6]), 0) / anchor_width))  # the left side anchor of the text box, downwards
right_anchor_num = int(math.ceil(min(max(box[2], box[4]), img.shape[1]) / anchor_width))  # the right side anchor of the text box, upwards

  

Training process:

Training: we use SGD as the optimizer and set two learning rates: a larger lr for the first N epochs and a smaller lr afterwards so the model converges better.

During training we track four losses: total_cls_loss, total_v_reg_loss, total_o_reg_loss, and total_loss (the sum of the first three).
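The internals of Loss.CTPN_Loss are not shown in this post; as a rough, hypothetical sketch (the function name and tensor shapes below are made up), the three terms could be combined with the weights λ1 = 1.0 and λ2 = 2.0 used in the CTPN paper:

# Hypothetical sketch, not the repository's Loss.CTPN_Loss implementation.
import torch
import torch.nn.functional as F

def ctpn_loss_sketch(cls_logits, cls_labels, v_pred, v_gt, o_pred, o_gt,
                     lambda1=1.0, lambda2=2.0):
    cls_loss = F.cross_entropy(cls_logits, cls_labels)   # text / non-text classification
    v_reg_loss = F.smooth_l1_loss(v_pred, v_gt)          # vertical coordinate (cy, h) regression
    o_reg_loss = F.smooth_l1_loss(o_pred, o_gt)          # side-refinement offset regression
    total = cls_loss + lambda1 * v_reg_loss + lambda2 * o_reg_loss
    return total, cls_loss, v_reg_loss, o_reg_loss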

net = Net.CTPN()  # build the network
for name, value in net.named_parameters():
    if name in no_grad:
        value.requires_grad = False
    else:
        value.requires_grad = True
# for name, value in net.named_parameters():
#     print('name: {0}, grad: {1}'.format(name, value.requires_grad))
net.load_state_dict(torch.load('./lib/vgg16.model'))
# net.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
lib.utils.init_weight(net)
if using_cuda:
    net.cuda()
net.train()
print(net)

criterion = Loss.CTPN_Loss(using_cuda=using_cuda)  # build the loss

train_im_list, train_gt_list, val_im_list, val_gt_list = create_train_val()  # get the training / validation data
total_iter = len(train_im_list)
print("total training image num is %s" % len(train_im_list))
print("total val image num is %s" % len(val_im_list))

train_loss_list = []
test_loss_list = []

# start the training loop
for i in range(epoch):
    if i >= change_epoch:
        lr = lr_behind
    else:
        lr = lr_front
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005)
    # optimizer = optim.Adam(net.parameters(), lr=lr)
    iteration = 1
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()

    random.shuffle(train_im_list)  # shuffle the training set
    # print(random_im_list)
    for im in train_im_list:
        root, file_name = os.path.split(im)
        root, _ = os.path.split(root)
        name, _ = os.path.splitext(file_name)
        gt_name = 'gt_' + name + '.txt'

        gt_path = os.path.join(root, "train_gt", gt_name)

        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = lib.dataset_handler.read_gt_file(gt_path)  # read the corresponding label file
        # print("processing image %s" % os.path.join(img_root1, im))
        img = cv2.imread(im)
        if img is None:
            iteration += 1
            continue

        img, gt_txt = lib.dataset_handler.scale_img(img, gt_txt)  # scale the image and its labels
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)  # forward pass to get the predictions
        del tensor_img

        # transform bbox gt to anchor gt for training
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []

        visual_img = copy.deepcopy(img)  # this copy is used to visualise the labels

        try:
            # loop all bbox in one image
            for box in gt_txt:
                # generate anchors from one bbox
                gt_anchor, visual_img = lib.generate_gt_anchor.generate_gt_anchor(img, box, draw_img_gt=visual_img)  # generate the anchor labels for this box
                positive1, negative1, vertical_reg1, side_refinement_reg1 = lib.tag_anchor.tag_anchor(gt_anchor, score, box)  # map the predictions onto anchor-level targets
                positive += positive1
                negative += negative1
                vertical_reg += vertical_reg1
                side_refinement_reg += side_refinement_reg1
        except:
            print("warning: img %s raise error!" % im)
            iteration += 1
            continue

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            iteration += 1
            continue

        cv2.imwrite(os.path.join(DRAW_PREFIX, file_name), visual_img)
        optimizer.zero_grad()
        # compute the losses
        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        # back-propagation
        loss.backward()
        optimizer.step()
        iteration += 1
        # save gpu memory by transferring loss to float
        total_loss += float(loss)
        total_cls_loss += float(cls_loss)
        total_v_reg_loss += float(v_reg_loss)
        total_o_reg_loss += float(o_reg_loss)

        if iteration % display_iter == 0:
            end_time = time.time()
            total_time = end_time - start_time
            print('Epoch: {2}/{3}, Iteration: {0}/{1}, loss: {4}, cls_loss: {5}, v_reg_loss: {6}, o_reg_loss: {7}, {8}'.
                  format(iteration, total_iter, i, epoch, total_loss / display_iter, total_cls_loss / display_iter,
                         total_v_reg_loss / display_iter, total_o_reg_loss / display_iter, im))

            logger.info('Epoch: {2}/{3}, Iteration: {0}/{1}'.format(iteration, total_iter, i, epoch))
            logger.info('loss: {0}'.format(total_loss / display_iter))
            logger.info('classification loss: {0}'.format(total_cls_loss / display_iter))
            logger.info('vertical regression loss: {0}'.format(total_v_reg_loss / display_iter))
            logger.info('side-refinement regression loss: {0}'.format(total_o_reg_loss / display_iter))

            train_loss_list.append(total_loss)

            total_loss = 0
            total_cls_loss = 0
            total_v_reg_loss = 0
            total_o_reg_loss = 0
            start_time = time.time()

        # periodically evaluate the model
        if iteration % val_iter == 0:
            net.eval()
            logger.info('Start evaluate at {0} epoch {1} iteration.'.format(i, iteration))
            val_loss = evaluate.val(net, criterion, val_batch_size, using_cuda, logger, val_im_list)
            logger.info('End evaluate.')
            net.train()
            start_time = time.time()
            test_loss_list.append(val_loss)

        # periodically save the model
        if iteration % save_iter == 0:
            print('Model saved at ./model/ctpn-{0}-{1}.model'.format(i, iteration))
            torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-{1}.model'.format(i, iteration))

    print('Model saved at ./model/ctpn-{0}-end.model'.format(i))
    torch.save(net.state_dict(), './model/ctpn-msra_ali-{0}-end.model'.format(i))

# plot the loss curves
draw_loss_plot(train_loss_list, test_loss_list)

There are rules for scaling the image: the shortest side of the scaled image must equal 600. We use

  scale = float(shortest_side)/float(min(height, width))

to obtain the scaling coefficient and resize the original image, and we scale the labels by the same coefficient.

 

def scale_img(img, gt, shortest_side=600):
    height = img.shape[0]
    width = img.shape[1]
    scale = float(shortest_side) / float(min(height, width))
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale)
    # force the shortest side to exactly 600 (cv2.resize takes dsize as (width, height))
    if img.shape[0] < img.shape[1] and img.shape[0] != 600:
        img = cv2.resize(img, (img.shape[1], 600))
    elif img.shape[0] > img.shape[1] and img.shape[1] != 600:
        img = cv2.resize(img, (600, img.shape[0]))
    elif img.shape[0] != 600:
        img = cv2.resize(img, (600, 600))
    h_scale = float(img.shape[0]) / float(height)
    w_scale = float(img.shape[1]) / float(width)
    scale_gt = []
    for box in gt:
        scale_box = []
        for i in range(len(box)):
            # x coordinate
            if i % 2 == 0:
                scale_box.append(int(int(box[i]) * w_scale))
            # y coordinate
            else:
                scale_box.append(int(int(box[i]) * h_scale))
        scale_gt.append(scale_box)
    return img, scale_gt

Validation set evaluation:

def val(net, criterion, batch_num, using_cuda, logger):
    img_root = '../dataset/OCR_dataset/ctpn/test_im'
    gt_root = '../dataset/OCR_dataset/ctpn/test_gt'
    img_list = os.listdir(img_root)
    total_loss = 0
    total_cls_loss = 0
    total_v_reg_loss = 0
    total_o_reg_loss = 0
    start_time = time.time()
    for im in random.sample(img_list, batch_num):
        name, _ = os.path.splitext(im)
        gt_name = 'gt_' + name + '.txt'
        gt_path = os.path.join(gt_root, gt_name)
        if not os.path.exists(gt_path):
            print('Ground truth file of image {0} not exists.'.format(im))
            continue

        gt_txt = Dataset.port.read_gt_file(gt_path, have_BOM=True)
        img = cv2.imread(os.path.join(img_root, im))
        img, gt_txt = Dataset.scale_img(img, gt_txt)
        tensor_img = img[np.newaxis, :, :, :]
        tensor_img = tensor_img.transpose((0, 3, 1, 2))
        if using_cuda:
            tensor_img = torch.FloatTensor(tensor_img).cuda()
        else:
            tensor_img = torch.FloatTensor(tensor_img)

        vertical_pred, score, side_refinement = net(tensor_img)
        del tensor_img
        positive = []
        negative = []
        vertical_reg = []
        side_refinement_reg = []
        for box in gt_txt:
            gt_anchor = Dataset.generate_gt_anchor(img, box)
            positive1, negative1, vertical_reg1, side_refinement_reg1 = Net.tag_anchor(gt_anchor, score, box)
            positive += positive1
            negative += negative1
            vertical_reg += vertical_reg1
            side_refinement_reg += side_refinement_reg1

        if len(vertical_reg) == 0 or len(positive) == 0 or len(side_refinement_reg) == 0:
            batch_num -= 1
            continue

        loss, cls_loss, v_reg_loss, o_reg_loss = criterion(score, vertical_pred, side_refinement, positive,
                                                           negative, vertical_reg, side_refinement_reg)
        total_loss += loss
        total_cls_loss += cls_loss
        total_v_reg_loss += v_reg_loss
        total_o_reg_loss += o_reg_loss

    end_time = time.time()
    total_time = end_time - start_time
    print('####################  Start evaluate  ####################')
    print('loss: {0}'.format(total_loss / float(batch_num)))
    logger.info('Evaluate loss: {0}'.format(total_loss / float(batch_num)))

    print('classification loss: {0}'.format(total_cls_loss / float(batch_num)))
    logger.info('Evaluate classification loss: {0}'.format(total_cls_loss / float(batch_num)))

    print('vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))
    logger.info('Evaluate vertical regression loss: {0}'.format(total_v_reg_loss / float(batch_num)))

    print('side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))
    logger.info('Evaluate side-refinement regression loss: {0}'.format(total_o_reg_loss / float(batch_num)))

    print('{1} iterations for {0} seconds.'.format(total_time, batch_num))
    print('#####################  Evaluate end  #####################')
    print('\n')

  

Training and inference results

Testing: given an input image, produce the final detection result.

def infer_one(im_name, net):
    im = cv2.imread(im_name)
    im = lib.dataset_handler.scale_img_only(im)  # scale the image
    img = copy.deepcopy(im)
    img = img.transpose(2, 0, 1)
    img = img[np.newaxis, :, :, :]
    img = torch.Tensor(img)
    v, score, side = net(img, val=True)  # forward pass to get the predictions
    result = []

    # collect the anchors whose text score exceeds the threshold
    for i in range(score.shape[0]):
        for j in range(score.shape[1]):
            for k in range(score.shape[2]):
                if score[i, j, k, 1] > THRESH_HOLD:
                    result.append((j, k, i, float(score[i, j, k, 1].detach().numpy())))

    # NMS filtering
    for_nms = []
    for box in result:
        pt = lib.utils.trans_to_2pt(box[1], box[0] * 16 + 7.5, anchor_height[box[2]])
        for_nms.append([pt[0], pt[1], pt[2], pt[3], box[3], box[0], box[1], box[2]])
    for_nms = np.array(for_nms, dtype=np.float32)
    nms_result = lib.nms.cpu_nms(for_nms, NMS_THRESH)

    out_nms = []
    for i in nms_result:
        out_nms.append(for_nms[i, 0:8])

    # decide which anchors belong to the same group
    connect = get_successions(v, out_nms)
    # merge each group of anchors into one text line
    texts = get_text_lines(connect, im.shape)

    for box in texts:
        box = np.array(box)
        print(box)
        lib.draw_image.draw_ploy_4pt(im, box[0:8])

    _, basename = os.path.split(im_name)
    cv2.imwrite('./infer_' + basename, im)

  

The inference code above calls get_successions to collect all the anchors belonging to one predicted text line. In other words, we end up with many anchors predicted to contain text, but how do we know which anchors form a single text line? We need an anchor-merging algorithm, which is the most difficult part of the CTPN implementation.

The CTPN paper describes text-line construction as follows: it is straightforward, simply connecting consecutive text proposals whose text/non-text score is > 0.7. A text line is built like this:

  • First, define a neighbour Bj for a proposal Bi (Bj -> Bi), where:
  1. Bj is the closest proposal to Bi in horizontal distance;
  2. that distance is less than 50 pixels;
  3. their vertical overlap is > 0.7.

The theory looks simple, but implementing it yourself is full of pitfalls — truly a case of "what you learn on paper is shallow; to really understand it you must do it yourself". get_successions takes v, which holds each predicted anchor's h and y information, and anchors, which holds the four corner coordinates of each anchor.
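The repository's get_successions is not reproduced here; below is only a rough sketch of the neighbour rule listed above, under the assumption that "vertical overlap" is measured as the intersection height over the smaller anchor height (the actual code may define it differently):

# Rough sketch of the neighbour rule; not the repository's actual get_successions.
# Each box is assumed to be [x_left, y_top, x_right, y_bottom].
def vertical_overlap(box_a, box_b):
    inter = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    min_h = min(box_a[3] - box_a[1], box_b[3] - box_b[1])
    return inter / (min_h + 1e-6)

def find_successor(idx, boxes, max_dist=50, overlap_thresh=0.7):
    """Return the index of the nearest right-hand neighbour of boxes[idx], or None."""
    best, best_dist = None, max_dist
    for j, cand in enumerate(boxes):
        dist = cand[0] - boxes[idx][0]
        if j == idx or dist <= 0 or dist >= best_dist:
            continue
        if vertical_overlap(boxes[idx], cand) > overlap_thresh:
            best, best_dist = j, dist
    return best

Repeating this successor search from each starting anchor chains the proposals into groups, which get_text_lines then merges into complete text lines.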

Detection results and summary

First, let's look at the text-detection results of the trained model. For easier inspection, I draw both the anchors and the final merged text boxes:

 

 

Some observations and thoughts from the implementation:

  1. CTPN does not handle rotated text well, which follows from the nature of the algorithm: it is hard to merge a series of fixed-width quadrilaterals into an accurate text box. Some anchors are difficult to group at all, and even when grouped they are hard to recover precisely into a complete, accurate text rectangle (a weakness at inference time). For horizontally arranged text, though, I think the approach works quite well.
  2. The side-refinement in CTPN contributes little in practice. If the detected text is fed directly to recognition, the few pixels that side-refinement corrects are negligible.
  3. CTPN has rather many intermediate steps: from anchor label generation, to the loss computation, to the final text-line construction at inference time, each stage introduces some error. This drawback is also pointed out in the EAST paper: the simpler the training pipeline and the fewer the intermediate stages, the easier it is to guarantee accuracy.
  4. Judging from the results, CTPN tends to have low precision but high recall. Detection based on 16-pixel-wide anchors seems quite prone to false positives on large non-text graphics (such as road signs), because the anchor width is simply too small; even though an LSTM links neighbouring anchors, it still feels like "missing the forest for the trees". So CTPN does not work well on text that is too large or too small.
  5. CTPN is a fairly old algorithm (2016). Its idea was quite novel at the time, but it also has many shortcomings. Newer methods have largely addressed them; EAST and PixelNet, for example, are excellent more recent algorithms.

The complete CTPN implementation can be found on the author's GitHub.

 
