tensoflow_yolov3 计算平均识别个数(平均识别数)
生活随笔
收集整理的这篇文章主要介绍了
tensoflow_yolov3 计算平均识别个数(平均识别数)
小编觉得挺不错的,现在分享给大家,帮大家做个参考.
# -*- coding: utf-8 -*-
"""
@File : 20200221_Target_Recognition_光照度对模型识别率影响(计算平均识别个数).py
@Time : 2020/2/21 11:07
@Author : Dontla
@Email : sxana@qq.com
@Software: PyCharm
"""import tracebackimport cv2
import numpy as np
import tensorflow as tf
import core.utils as utils
from core.config import cfg
from core.yolov3 import YOLOV3
import pyrealsense2 as rs
import time
import sysclass YoloTest(object):def __init__(self):# D·C 191111:__C.TEST.INPUT_SIZE = 544self.input_size = cfg.TEST.INPUT_SIZEself.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE# Dontla 191106注释:初始化class.names文件的字典信息属性self.classes = utils.read_class_names(cfg.YOLO.CLASSES)# D·C 191115:类数量属性self.num_classes = len(self.classes)self.anchors = np.array(utils.get_anchors(cfg.YOLO.ANCHORS))# D·C 191111:__C.TEST.SCORE_THRESHOLD = 0.3self.score_threshold = cfg.TEST.SCORE_THRESHOLD# D·C 191120:__C.TEST.IOU_THRESHOLD = 0.45self.iou_threshold = cfg.TEST.IOU_THRESHOLDself.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY# D·C 191120:__C.TEST.ANNOT_PATH = "./data/dataset/Dontla/20191023_Artificial_Flower/test.txt"self.annotation_path = cfg.TEST.ANNOT_PATH# D·C 191120:__C.TEST.WEIGHT_FILE = "./checkpoint/f_g_c_weights_files/yolov3_test_loss=15.8845.ckpt-47"self.weight_file = cfg.TEST.WEIGHT_FILE# D·C 191115:可写标记(bool类型值)self.write_image = cfg.TEST.WRITE_IMAGE# D·C 191115:__C.TEST.WRITE_IMAGE_PATH = "./data/detection/"(识别图片画框并标注文本后写入的图片路径)self.write_image_path = cfg.TEST.WRITE_IMAGE_PATH# D·C 191116:TEST.SHOW_LABEL设置为Trueself.show_label = cfg.TEST.SHOW_LABEL# D·C 191120:创建命名空间“input”with tf.name_scope('input'):# D·C 191120:建立变量(创建占位符开辟内存空间)self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')self.trainable = tf.placeholder(dtype=tf.bool, name='trainable')model = YOLOV3(self.input_data, self.trainable)self.pred_sbbox, self.pred_mbbox, self.pred_lbbox = model.pred_sbbox, model.pred_mbbox, model.pred_lbbox# D·C 191120:创建命名空间“指数滑动平均”with tf.name_scope('ema'):ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay)# D·C 191120:在允许软设备放置的会话中启动图形并记录放置决策。(不懂啥意思。。。)allow_soft_placement=True表示允许tf自动选择可用的GPU和CPUself.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))# D·C 191120:variables_to_restore()用于加载模型计算滑动平均值时将影子变量直接映射到变量本身self.saver = tf.train.Saver(ema_obj.variables_to_restore())# D·C 191120:用于下次训练时恢复模型self.saver.restore(self.sess, self.weight_file)# 摄像头序列号# self.cam_serials = ['838212073161']# self.cam_serials = ['827312070790']self.cam_serials = ['836612072369']self.cam_num = len(self.cam_serials)def predict(self, image):# D·C 191107:复制一份图片的镜像,避免对图片直接操作改变图片的内在属性org_image = np.copy(image)# D·C 191107:获取图片尺寸org_h, org_w, _ = org_image.shape# D·C 191108:该函数将源图结合input_size,将其转换成预投喂的方形图像(作者默认544×544,中间为缩小尺寸的源图,上下空区域为灰图):image_data = utils.image_preprocess(image, [self.input_size, self.input_size])# D·C 191108:打印维度看看:# print(image_data.shape)# (544, 544, 3)# D·C 191108:创建新轴,不懂要创建新轴干嘛?image_data = image_data[np.newaxis, ...]# D·C 191108:打印维度看看:# print(image_data.shape)# (1, 544, 544, 3)# D·C 191110:三个box可能存放了预测框图(可能是N多的框,有用的没用的重叠的都在里面)的信息(但是打印出来的值完全看不懂啊喂?)pred_sbbox, pred_mbbox, pred_lbbox = self.sess.run([self.pred_sbbox, self.pred_mbbox, self.pred_lbbox],feed_dict={self.input_data: image_data,self.trainable: False})# D·C 191110:打印三个box的类型、形状和值看看:# print(type(pred_sbbox))# print(type(pred_mbbox))# print(type(pred_lbbox))# 都是<class 'numpy.ndarray'># print(pred_sbbox.shape)# print(pred_mbbox.shape)# print(pred_lbbox.shape)# (1, 68, 68, 3, 6)# (1, 34, 34, 3, 6)# (1, 17, 17, 3, 6)# print(pred_sbbox)# print(pred_mbbox)# print(pred_lbbox)# D·C 191110:(-1,6)表示不知道有多少行,反正你给我整成6列,然后concatenate又把它们仨给叠起来,最终得到无数个6列数组(后面self.num_classes)个数存放的貌似是这个框属于类的概率)pred_bbox = np.concatenate([np.reshape(pred_sbbox, (-1, 5 + self.num_classes)),np.reshape(pred_mbbox, (-1, 5 + self.num_classes)),np.reshape(pred_lbbox, (-1, 5 + self.num_classes))], axis=0)# D·C 191111:打印pred_bbox和它的维度看看:# print(pred_bbox)# print(pred_bbox.shape)# (18207, 6)# D·C 191111:猜测是第一道过滤,过滤掉score_threshold以下的图片,过滤完之后少了好多:# D·C 191115:bboxes维度为[n,6],前四列是坐标,第五列是得分,第六列是对应类下标bboxes = utils.postprocess_boxes(pred_bbox, (org_h, org_w), self.input_size, self.score_threshold)# D·C 191111:猜测是第二道过滤,过滤掉iou_threshold以下的图片:bboxes = utils.nms(bboxes, self.iou_threshold)return bboxesdef cam_conti_veri(self, cam_num, ctx):"""摄像头连续验证、连续验证机制"""# D·C 1911202:创建最大验证次数max_veri_times;创建连续稳定值continuous_stable_value,用于判断设备重置后是否处于稳定状态max_veri_times = 100continuous_stable_value = 5print('\n', end='')print('开始连续验证,连续验证稳定值:{},最大验证次数:{}:'.format(continuous_stable_value, max_veri_times))continuous_value = 0veri_times = 0while True:devices = ctx.query_devices()connected_cam_num = len(devices)print('摄像头个数:{}'.format(connected_cam_num))if connected_cam_num == cam_num:continuous_value += 1if continuous_value == continuous_stable_value:breakelse:continuous_value = 0veri_times += 1if veri_times == max_veri_times:print("检测超时,请检查摄像头连接!")sys.exit()def cam_hardware_reset(self, ctx, cam_serials):"""循环reset摄像头"""# hardware_reset()后是不是应该延迟一段时间?不延迟就会报错print('\n', end='')print('开始初始化摄像头:')for dev in ctx.query_devices():# 先将设备的序列号放进一个变量里,免得在下面for循环里访问设备的信息过多(虽然不知道它会不会每次都重新访问)dev_serial = dev.get_info(rs.camera_info.serial_number)# 匹配序列号,重置我们需重置的特定摄像头(注意两个for循环顺序,哪个在外哪个在内很重要,不然会导致刚重置的摄像头又被访问导致报错)for serial in cam_serials:if serial == dev_serial:dev.hardware_reset()# 像下面这条语句居然不会报错,不是刚刚才重置了dev吗?莫非区别在于没有通过for循环ctx.query_devices()去访问?# 是不是刚重置后可以通过ctx.query_devices()去查看有这个设备,但是却没有存储设备地址?如果是这样,# 也就能够解释为啥能够通过len(ctx.query_devices())函数获取设备数量,但访问序列号等信息就会报错的原因了print('摄像头{}初始化成功'.format(dev.get_info(rs.camera_info.serial_number)))# 如果只有一个摄像头,要让它睡够5秒(避免出错,保险起见)time.sleep(5 / len(cam_serials))def get_cam_serials(self):passdef calculate_detection_num(self, calcu_list, detect_num):"""计算一段次数内平均识别个数"""# 将列表calcu_list作为队列,右为头,左为尾,头为先进的帧,尾为后进的帧# 定义需做平均的队列帧数量frame_num = 50# 判断传进来的队列大小,如果小于frame_num就把元素添加到左边,如果大于或等于50,就把右边超过50的咔掉,并抛出最右边那个,将元素加到最左边if len(calcu_list) < frame_num:calcu_list.insert(0, detect_num)else:calcu_list = calcu_list[:frame_num]calcu_list.pop()calcu_list.insert(0, detect_num)# if len(calcu_list) > frame_num:# calcu_list = calcu_list[:frame_num]# elif len(calcu_list)==frame_num:# 求列表均值average_num = np.mean(calcu_list)return calcu_list, average_numdef dontla_evaluate_detect(self):# 摄像头个数(在这里设置所需使用摄像头的总个数)ctx = rs.context()# 连续验证机制self.cam_conti_veri(self.cam_num, ctx)# 循环reset摄像头self.cam_hardware_reset(ctx, self.cam_serials)# 连续验证机制self.cam_conti_veri(self.cam_num, ctx)# 打印摄像头序列号和接口号并创建需要显示在窗口上的备注信息字符串列表(窗口名)print('\n', end='')cam_id = 0serial_list = []for i in ctx.query_devices():cam_id += 1serial_list.append('camera{}; serials number {}; usb port {}'.format(cam_id, i.get_info(rs.camera_info.serial_number),i.get_info(rs.camera_info.usb_type_descriptor)))print('serial number {}:{};usb port:{}'.format(cam_id, i.get_info(rs.camera_info.serial_number),i.get_info(rs.camera_info.usb_type_descriptor)))# print(serial_list)# 配置各个摄像头的基本对象for i in range(self.cam_num):# D·C 191203:括号里是否有必要加ctx,加了没加好像没多大区别,但不加它又会提示黄色locals()['pipeline' + str(i)] = rs.pipeline(ctx)locals()['config' + str(i)] = rs.config()# Dontla 20200221 存疑,为何不以前面指定的摄像头序列号启动,而要重新获取序列号?locals()['serial' + str(i)] = ctx.devices[i].get_info(rs.camera_info.serial_number)locals()['config' + str(i)].enable_device(locals()['serial' + str(i)])locals()['config' + str(i)].enable_stream(rs.stream.depth, 640, 360, rs.format.z16, 30)locals()['config' + str(i)].enable_stream(rs.stream.color, 640, 360, rs.format.bgr8, 30)locals()['pipeline' + str(i)].start(locals()['config' + str(i)])# 创建对齐对象(深度对齐颜色)locals()['align' + str(i)] = rs.align(rs.stream.color)# 运行流并进行识别print('\n', end='')print('开始识别:')try:# 设置break标志,方便按下按钮跳出循环退出窗口break2 = False# 初始化计数列表calcu_list = []while True:for i in range(self.cam_num):locals()['frames' + str(i)] = locals()['pipeline' + str(i)].wait_for_frames()# 获取对齐帧集locals()['aligned_frames' + str(i)] = locals()['align' + str(i)].process(locals()['frames' + str(i)])# 获取对齐后的深度帧和彩色帧locals()['aligned_depth_frame' + str(i)] = locals()['aligned_frames' + str(i)].get_depth_frame()locals()['color_frame' + str(i)] = locals()['aligned_frames' + str(i)].get_color_frame()if not locals()['aligned_depth_frame' + str(i)] or not locals()['color_frame' + str(i)]:continue# 获取颜色帧内参locals()['color_profile' + str(i)] = locals()['color_frame' + str(i)].get_profile()locals()['cvsprofile' + str(i)] = rs.video_stream_profile(locals()['color_profile' + str(i)])locals()['color_intrin' + str(i)] = locals()['cvsprofile' + str(i)].get_intrinsics()locals()['color_intrin_part' + str(i)] = [locals()['color_intrin' + str(i)].ppx,locals()['color_intrin' + str(i)].ppy,locals()['color_intrin' + str(i)].fx,locals()['color_intrin' + str(i)].fy]locals()['color_image' + str(i)] = np.asanyarray(locals()['color_frame' + str(i)].get_data())locals()['bboxes_pr' + str(i)] = self.predict(locals()['color_image' + str(i)])# Dontla 20200221 打印识别个数# print(np.array(locals()['bboxes_pr' + str(i)]).shape)detect_num = len(locals()['bboxes_pr' + str(i)])# print('识别个数:{}'.format(detect_num))# Dontla 20200221 计算平均识别个数(这里只针对一个摄像头情况,多个摄像头到时再重构)calcu_list, mean_detect_num = self.calculate_detection_num(calcu_list, detect_num)# Dontla 20200221 打印平均识别个数print(calcu_list)print('平均识别个数:{}'.format(mean_detect_num))locals()['image' + str(i)] = utils.draw_bbox(locals()['color_image' + str(i)],locals()['bboxes_pr' + str(i)],locals()['aligned_depth_frame' + str(i)],locals()['color_intrin_part' + str(i)],show_label=self.show_label)# D·C 191202:本想创建固定比例的大小可调的窗口,发现无法使用,opencv bug?# cv2.namedWindow('{}'.format(serial_list[i]),# flags=cv2.WINDOW_NORMAL | cv2.WINDOW_FREERATIO | cv2.WINDOW_GUI_EXPANDED)cv2.imshow('{}'.format(serial_list[i]), locals()['image' + str(i)])key = cv2.waitKey(1)# 如果按下ESC,则跳出循环if key == 27:# 貌似直接用return也行# returnbreak2 = Truebreakif break2:breakexcept Exception as e:print("掉帧了!")traceback.print_exc()traceback.print_exc(file=open('traceback.txt', 'w+'))finally:# 大概觉得先关闭窗口再停止流比较靠谱# 销毁所有窗口cv2.destroyAllWindows()print('\n', end='')print('已关闭所有窗口!')# 停止所有流for i in range(self.cam_num):locals()['pipeline' + str(i)].stop()print('正在停止所有流,请等待数秒至程序稳定结束!')if __name__ == '__main__':YoloTest().dontla_evaluate_detect()print('程序已结束!')
总结
以上是生活随笔为你收集整理的tensoflow_yolov3 计算平均识别个数(平均识别数)的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: numpy报错:ModuleNotFou
- 下一篇: 【中级软考】软件质量模型的六大特性27个