YOLO在检测到物体后会画出ROI,并在其左上角标注类别,如YOLO官网给出的预测图所示(图略)。这里的类别名称是从配置文件中读取后绘制出来的。想让它支持中文,乍一看很简单,似乎只需修改对应的配置文件就行了。然而理想很丰满,现实很骨感:真这样做,只能得到一堆乱码。下面我们一步步阅读代码,看看这是为什么。
我工作中使用YOLO主要是做物体的实时检测,所以以实时检测为入口分析代码。 执行命令
./darknet detector demo cfg/coco.data cfg/yolo.cfg yolo.weights
从darknet.c的main函数看起:
else if (0 == strcmp(argv[1], "detector")){ run_detector(argc, argv);进入detector.c查看run_detector函数
else if(0==strcmp(argv[2], "demo")) { list *options = read_data_cfg(datacfg); int classes = option_find_int(options, "classes", 20); char *name_list = option_find_str(options, "names", "data/names.list"); char **names = get_labels(name_list); demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh); }coco.data
names = data/coco.names
coco.names
person bicycle car motorbike aeroplane bus train truck boat ......物体类别信息就是从这得来的,是不是也从这画,我们接着看。 demo.c
void demo(char *cfgfile, char *weightfile, float thresh, int cam_index, const char *filename, char **names, int classes, int frame_skip, char *prefix, float hier_thresh) { //skip = frame_skip; image **alphabet = load_alphabet(); int delay = frame_skip; demo_names = names; demo_alphabet = alphabet; demo_classes = classes; demo_thresh = thresh; demo_hier_thresh = hier_thresh; printf("Demo\n"); net = parse_network_cfg(cfgfile); ...... while(1){ ++count; if(1){ if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed"); if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed"); ......demo_alphabet在绘制ROI时起到重要作用,稍后会做分析。 fetch_thread负责从camera读图像数据,detect_thread负责从检测物体。
void *detect_in_thread(void *ptr) { /* Detection step run on its own thread: predicts on det_s, averages the stored predictions, extracts and NMS-filters boxes, prints FPS, rotates the frame ring buffer and draws the results on det. ptr is unused; always returns 0. */ float nms = .4; /* overlap threshold handed to do_nms / do_nms_obj -- presumably IoU; a value of 0 would skip NMS below */ layer l = net.layers[net.n-1]; float *X = det_s.data; float *prediction = network_predict(net, X); /* keep this frame's raw output in slot demo_index, then average all FRAMES slots into avg; boxes are extracted from that average via l.output */ memcpy(predictions[demo_index], prediction, l.outputs*sizeof(float)); mean_arrays(predictions, FRAMES, l.outputs, avg); l.output = avg; free_image(det_s); if(l.type == DETECTION){ get_detection_boxes(l, 1, 1, demo_thresh, probs, boxes, 0); } else if (l.type == REGION){ get_region_boxes(l, 1, 1, demo_thresh, probs, boxes, 0, 0, demo_hier_thresh); } else { error("Last layer must produce detections\n"); } if (l.softmax_tree && nms) do_nms_obj(boxes, probs, l.w*l.h*l.n, l.classes, nms); else if (nms) do_nms(boxes, probs, l.w*l.h*l.n, l.classes, nms); /* ANSI escapes: clear the terminal and move the cursor to the top-left before printing stats */ printf("\033[2J"); printf("\033[1;1H"); printf("\nFPS:%.1f\n",fps); printf("Objects:\n\n"); /* store the current frame in the ring and pick the frame FRAMES/2 + 1 slots ahead for drawing -- presumably so the displayed frame lines up with the averaged predictions */ images[demo_index] = det; det = images[(demo_index + FRAMES/2 + 1)%FRAMES]; demo_index = (demo_index + 1)%FRAMES; draw_detections(det, l.w*l.h*l.n, demo_thresh, boxes, probs, demo_names, demo_alphabet, demo_classes); return 0;} 检测并画出ROI的核心流程: 1.network_predict 预测 2.get_region_boxes 获取ROI 3.do_nms 非极大值抑制 4.draw_detections 画出ROI
关于前三个环节的具体实现,以后会专门写一系列YOLO代码解析的博客详细分析。下面看本篇的关键函数draw_detections(位于image.c):
void draw_detections(image im, int num, float thresh, box *boxes, float **probs, char **names, image **alphabet, int classes) { /* For each of the num candidates: take its most probable class (max_index over probs[i]); if that probability exceeds thresh, print "name: NN%", draw the box with draw_box_width in a per-class colour and, when alphabet is non-NULL, draw a text label built from names[class]. */ int i; for(i = 0; i < num; ++i){ int class = max_index(probs[i], classes); float prob = probs[i][class]; if(prob > thresh){ int width = im.h * .012; /* if(0): disabled alternative width formula -- never runs */ if(0){ width = pow(prob, 1./2.)*10+1; alphabet = 0; } printf("%s: %.0f%%\n", names[class], prob*100); /* colour depends only on class, so each class always gets the same colour */ int offset = class*123457 % classes; float red = get_color(2,offset,classes); float green = get_color(1,offset,classes); float blue = get_color(0,offset,classes); /* NOTE(review): in this one-line rendering the "//width" comment below swallows the rest of the function; in the multi-line original it presumably comments out only that one statement */ float rgb[3]; //width = prob*20+2; rgb[0] = red; rgb[1] = green; rgb[2] = blue; box b = boxes[i]; int left = (b.x-b.w/2.)*im.w; int right = (b.x+b.w/2.)*im.w; int top = (b.y-b.h/2.)*im.h; int bot = (b.y+b.h/2.)*im.h; if(left < 0) left = 0; if(right > im.w-1) right = im.w-1; if(top < 0) top = 0; if(bot > im.h-1) bot = im.h-1; draw_box_width(im, left, top, right, bot, width, red, green, blue); if (alphabet) { image label = get_label(alphabet, names[class], (im.h*.03)/10); draw_label(im, top + width, left, label, rgb); /* NOTE(review): label is never freed -- get_label appears to hand back a newly built image (it frees its own intermediates and returns border_image's result), so this looks like a leak of one image per drawn label; consider free_image(label) here */ } } } }YOLOv2中每一帧image要预测出13*13*5个物体,在绘制ROI时只绘制概率大于threshold的ROI。 probs里保存着每13*13*5个object是80个class中每种的概率。举个例子,probs[i][0]是预测的第i个物体是person的概率,probs[i][1]是该物体是bicycle的概率,probs[i][2]是该物体是car的概率。 max_index函数会找出object最可能属于哪个类别。 draw_box_width会画出ROI,针对中文支持,我们最关心的是函数get_label和draw_label。
if (alphabet) { image label = get_label(alphabet, names[class], (im.h*.03)/10); draw_label(im, top + width, left, label, rgb); }我们先看下alphabet是什么。 在demo函数开始时曾执行过load_alphabet。
image **load_alphabet() { /* Loads glyph images "data/labels/<char code>_<size index>.png" for character codes 32..126 at size indices 0..7 and returns them as alphabets[size][char code]; every other slot stays zero-filled by calloc. */ int i, j; const int nsize = 8; /* NOTE(review): alphabets holds image pointers, so the element size should be sizeof(image *) (or sizeof *alphabets); sizeof(image) -- a struct with w/h/c and a data pointer -- only over-allocates here, it does not overflow */ image **alphabets = calloc(nsize, sizeof(image)); for(j = 0; j < nsize; ++j){ /* 128 slots indexed directly by the character code; only 32..126 are filled */ alphabets[j] = calloc(128, sizeof(image)); for(i = 32; i < 127; ++i){ char buff[256]; sprintf(buff, "data/labels/%d_%d.png", i, j); alphabets[j][i] = load_image_color(buff, 0, 0); } } return alphabets; }data/labels目录下是ASCII 32~126的九十多个字符8种尺寸的png图片。i代表字符对于的ASCII码,j是尺寸。 alphabet里保存了九十多个字符的8种尺寸的image。
image get_label(image **characters, char *string, int size) { /* Builds a label image for string: clamps size to at most 7, tiles together the glyph image of each byte of string, then adds a border of 0.25 * label height. Intermediate images are freed; the returned image is the one produced by border_image. */ if(size > 7) size = 7; image label = make_empty_image(0,0,0); while(*string){ /* NOTE(review): one glyph per byte. Every byte of a multi-byte UTF-8 (e.g. Chinese) character is >= 0x80: with a signed char the index is negative, with an unsigned char it is >= 128 -- either way outside the 128-entry per-size tables that load_alphabet allocates, so this is an out-of-bounds read. That is why simply writing Chinese names into the .names file cannot work. */ image l = characters[size][(int)*string]; image n = tile_images(label, l, -size - 1 + (size+1)/2); free_image(label); label = n; ++string; } image b = border_image(label, label.h*.25); free_image(label); return b; } void draw_label(image a, int r, int c, image label, const float *rgb) { /* Paints label into a with its top-left at row r, column c -- moved up by the label height when that stays inside a -- multiplying each label channel k by rgb[k]; pixels past a's right/bottom edge are skipped. Assumes label.c <= 3 since draw_detections passes a 3-element rgb -- TODO confirm for other callers. */ int w = label.w; int h = label.h; if (r - h >= 0) r = r - h; int i, j, k; for(j = 0; j < h && j + r < a.h; ++j){ for(i = 0; i < w && i + c < a.w; ++i){ for(k = 0; k < label.c; ++k){ float val = get_pixel(label, i, j, k); set_pixel(a, i+c, j+r, k, rgb[k] * val); } } } }YOLO绘制ROI的标签信息就是加载每个字符对应的image,想要加入对中文的支持,首先要绘制中文标签对应的png图片。(我曾想改变绘制便签信息的机制,直接使用opencv的cvPutText函数绘制标签,然而cvPutText也不支持中文,需要添加字体才行,而这又增加了依赖。)
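darknet提供了制作这些png图片的脚本data/labels/make_labels.py。为什么不能直接把names文件里的类别名改成中文?因为get_label是按字节逐个查找字符图片的:中文字符在UTF-8编码下占多个字节,且每个字节都不小于0x80,超出了load_alphabet为每种尺寸分配的128项字符表,会造成越界读取。所以这里换一种思路:为每个类别的中文名整体生成一张标签图片,绘制时按类别序号而不是按字符去查找。稍微修改一下脚本即可(注意:脚本在当前目录下生成图片,应在data/labels目录下运行,且生成的文件名前缀必须与image.c中load_alphabet读取的前缀一致):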
# -*- coding: utf-8 -*-
"""Render darknet label images for the 80 COCO classes in Chinese.

For every point size in SIZES, each class name in ``l`` is rendered with
ImageMagick's ``convert`` into ``cn_<class index>_<size index>.png`` in the
current working directory, where ``<size index> = size // 12 - 1`` (0..7,
matching the 8 label sizes darknet's load_alphabet loads). Run it from
``data/labels`` so the images land where image.c looks for them.

``l`` must list the names in the same order as data/coco.names, because the
class index is what ties an image file to a detection.

After generating the images, image.c has to be patched to load and draw
them; see
https://github.com/PaulChongPeng/darknet/commit/798fe7cc4176d452a83d63eb261d6129e397a521
"""
import subprocess

# Font containing the CJK glyphs; adjust if it lives elsewhere on your system.
FONT = "/usr/share/fonts/truetype/arphic/ukai.ttc"

# Point sizes; size // 12 - 1 maps them onto darknet's size indices 0..7.
SIZES = [12, 24, 36, 48, 60, 72, 84, 96]

l = ["人", "自行车", "车", "摩托车", "飞机", "大巴", "火车", "卡车", "船", "交通灯",
     "消防栓", "停止标识", "停车计时器", "长凳", "鸟", "猫", "狗", "马", "羊", "牛",
     "大象", "熊", "斑马", "长颈鹿", "背包", "伞", "手提包", "领带", "手提箱", "飞盘",
     "雪橇", "滑雪板", "体育用球", "风筝", "棒球棒", "棒球手套", "滑板", "冲浪板", "网球拍", "瓶子",
     "红酒杯", "杯子", "叉子", "小刀", "勺子", "碗", "香蕉", "苹果", "三明治", "橘子",
     "西兰花", "萝卜", "热狗", "披萨", "甜甜圈", "蛋糕", "椅子", "沙发", "盆栽", "床",
     "餐桌", "厕所", "显示器", "笔记本", "鼠标", "遥控", "键盘", "手机", "微波炉", "烤箱",
     "吐司机", "水槽", "冰箱", "书", "闹钟", "花瓶", "剪刀", "泰迪熊", "吹风机", "牙刷"]


def label_command(word, index, s):
    """Return the ``convert`` argv that renders ``word`` at point size ``s``.

    The output file is ``cn_<index>_<s // 12 - 1>.png``. Passing an argument
    list (no shell) keeps names containing spaces or quotes intact.
    """
    return [
        "convert",
        "-fill", "black",
        "-background", "white",
        "-bordercolor", "white",
        "-border", "4",
        "-font", FONT,
        "-pointsize", "%d" % s,
        "label:%s" % word,
        "cn_%d_%d.png" % (index, s // 12 - 1),
    ]


def make_labels(s):
    """Render every class name in ``l`` at point size ``s``.

    Raises subprocess.CalledProcessError if ``convert`` fails (e.g. missing
    font), instead of silently leaving a PNG missing that darknet would then
    fail to load.
    """
    for index, word in enumerate(l):
        subprocess.check_call(label_command(word, index, s))


if __name__ == "__main__":
    for size in SIZES:
        make_labels(size)
#include "opencv2/highgui/highgui_c.h" #include "opencv2/imgproc/imgproc_c.h" #include "opencv2/videoio/videoio_c.h" +#define CHINESE #endif @@ -88,6 +89,23 @@ image get_label(image **characters, char *string, int size) return b; } +#ifdef CHINESE +image get_label_chinese(image **characters, int class, int size) /* label = the pre-rendered image of the whole class name */ +{ + if(size > 7) size = 7; + image label = make_empty_image(0,0,0); + + image l = characters[size][class]; + image n = tile_images(label, l, -size - 1 + (size+1)/2); + free_image(label); + label = n; + + image b = border_image(label, label.h*.25); + free_image(label); + return b; +} +#endif + void draw_label(image a, int r, int c, image label, const float *rgb) { int w = label.w; @@ -169,11 +187,19 @@ image **load_alphabet() image **alphabets = calloc(nsize, sizeof(image)); for(j = 0; j < nsize; ++j){ alphabets[j] = calloc(128, sizeof(image)); +#ifdef CHINESE + for(i = 0; i < 80; ++i){ /* one image per COCO class; assumes exactly 80 classes */ + char buff[256]; + sprintf(buff, "data/labels/cn_%d_%d.png", i, j); /* cn_<class>_<size>.png, the names make_labels.py writes */ + alphabets[j][i] = load_image_color(buff, 0, 0); + } +#else for(i = 32; i < 127; ++i){ char buff[256]; sprintf(buff, "data/labels/%d_%d.png", i, j); alphabets[j][i] = load_image_color(buff, 0, 0); } +#endif } return alphabets; } @@ -220,7 +246,11 @@ void draw_detections(image im, int num, float thresh, box *boxes, float **probs, draw_box_width(im, left, top, right, bot, width, red, green, blue); if (alphabet) { +#ifdef CHINESE + image label = get_label_chinese(alphabet, class, (im.h*.03)/10); +#else image label = get_label(alphabet, names[class], (im.h*.03)/10); +#endif draw_label(im, top + width, left, label, rgb); } }