object_detector.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import torch
  2. import numpy as np
  3. from models.experimental import attempt_load
  4. from utils.general import non_max_suppression, scale_coords, letterbox
  5. from utils.torch_utils import select_device
  6. import cv2
  7. from random import randint
  8. class Detector(object):
  9. ''' 预测 '''
  10. def __init__(self):
  11. ''' 初始化 '''
  12. self.img_size = 640
  13. self.threshold = 0.4
  14. self.max_frame = 160
  15. self.half = False
  16. if torch.cuda.is_available():
  17. self.half = True
  18. self.init_model()
  19. def init_model(self):
  20. self.weights = 'weights/final.pt'
  21. self.device = '0' if torch.cuda.is_available() else 'cpu'
  22. self.device = select_device(self.device)
  23. model = attempt_load(self.weights, map_location=self.device)
  24. model.to(self.device).eval()
  25. model.half() if self.half else model.float()
  26. # torch.save(model, 'test.pt')
  27. self.m = model
  28. self.names = model.module.names if hasattr(
  29. model, 'module') else model.names
  30. self.colors = [
  31. (randint(0, 255), randint(0, 255), randint(0, 255)) for _ in self.names
  32. ]
  33. def preprocess(self, img):
  34. img0 = img.copy()
  35. # 图像缩放到指定尺寸
  36. img = letterbox(img, new_shape=self.img_size)[0]
  37. # 从BGR转换为RGB, 通过transpose(2, 0, 1)将通道维度移动到最前面
  38. img = img[:, :, ::-1].transpose(2, 0, 1)
  39. img = np.ascontiguousarray(img)
  40. img = torch.from_numpy(img).to(self.device)
  41. img = img.half() # 半精度
  42. img /= 255.0 # 图像归一化
  43. if img.ndimension() == 3:
  44. img = img.unsqueeze(0)
  45. return img0, img
  46. def plot_bboxes(self, image, bboxes, line_thickness=None):
  47. ''' 画框
  48. Args: image: 图片
  49. bboxes: 框
  50. line_thickness: 线的厚度
  51. Returns: image: 画框后的图片
  52. '''
  53. tl = line_thickness or round(
  54. 0.002 * (image.shape[0] + image.shape[1]) / 2) + 1 # line/font thickness
  55. for (x1, y1, x2, y2, cls_id, conf) in bboxes:
  56. color = self.colors[self.names.index(cls_id)]
  57. c1, c2 = (x1, y1), (x2, y2)
  58. cv2.rectangle(image, c1, c2, color,
  59. thickness=tl, lineType=cv2.LINE_AA)
  60. tf = max(tl - 1, 1) # font thickness
  61. t_size = cv2.getTextSize(
  62. cls_id, 0, fontScale=tl / 3, thickness=tf)[0]
  63. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  64. cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA) # filled
  65. cv2.putText(image, '{} ID-{:.2f}'.format(cls_id, conf), (c1[0], c1[1] - 2), 0, tl / 3,
  66. [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  67. return image
  68. def detect(self, im):
  69. ''' 预测
  70. Args: im: 图片
  71. Returns: im: 预测后的图片
  72. '''
  73. im0, img = self.preprocess(im)
  74. pred = self.m(img, augment=False)[0]
  75. pred = pred.float()
  76. pred = non_max_suppression(pred, self.threshold, 0.3)
  77. pred_boxes = []
  78. image_info = {}
  79. count = 0
  80. for det in pred:
  81. if det is not None and len(det):
  82. det[:, :4] = scale_coords(
  83. img.shape[2:], det[:, :4], im0.shape).round()
  84. for *x, conf, cls_id in det:
  85. lbl = self.names[int(cls_id)]
  86. x1, y1 = int(x[0]), int(x[1])
  87. x2, y2 = int(x[2]), int(x[3])
  88. pred_boxes.append(
  89. (x1, y1, x2, y2, lbl, conf))
  90. count += 1
  91. key = '{}-{:02}'.format(lbl, count)
  92. image_info[key] = ['{}×{}'.format(
  93. x2-x1, y2-y1), np.round(float(conf), 3)]
  94. im = self.plot_bboxes(im, pred_boxes)
  95. return im, image_info