In my day-to-day work I often use PyQt and onnxruntime to quickly put together demo applications for presentations and testing. Here I will use YOLOv12 as an example to walk through my approach.
First we need to train a YOLOv12 model and export it to an ONNX file. There is plenty of material online about this step, and the ultralytics framework can handle it, so I will not repeat it here; the steps below start directly from onnxruntime.
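For reference only, a minimal training-and-export sketch with ultralytics might look like the following; the pretrained weight name and dataset yaml are placeholders for your own setup:

```python
from ultralytics import YOLO

# Placeholder names: substitute your own pretrained weights and dataset yaml
model = YOLO("yolo12n-seg.pt")
model.train(data="my_dataset.yaml", epochs=100, imgsz=640)  # fine-tune on your own data
model.export(format="onnx")  # writes an .onnx file next to the trained weights
```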
At this point you need to write an onnxruntime-based inference class tailored to your model, covering preprocessing, postprocessing, visualization and so on, so that the input is an image and the output is the result, and so that the algorithm code stays independent of the GUI code. Below is a YOLOv12 inference class for instance segmentation that I wrote with reference to the ultralytics framework; a standalone usage sketch follows the class.

```python
import cv2
import numpy as np
import onnxruntime as ort
import torch

import yoloSeg.utils.ops as ops
from yoloSeg.utils.results import Results


class YOLOv12Seg:
    """
    YOLOv12 segmentation model for performing instance segmentation using ONNX Runtime.

    This class implements a YOLOv12 instance segmentation model using ONNX Runtime for inference. It handles
    preprocessing of input images, running inference with the ONNX model, and postprocessing the results to
    generate bounding boxes and segmentation masks.
    """

    def __init__(self, onnx_model, classes, conf=0.25, iou=0.7, imgsz=640):
        """Initialize the instance segmentation model using an ONNX model."""
        self.session = ort.InferenceSession(
            onnx_model,
            providers=["CPUExecutionProvider"]
            # if torch.cuda.is_available()
            # else ["CPUExecutionProvider"],
        )
        self.imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
        self.classes = classes
        self.conf = conf
        self.iou = iou

    def __call__(self, img):
        """Run inference on the input image using the ONNX model."""
        prep_img = self.preprocess(img, self.imgsz)
        outs = self.session.run(None, {self.session.get_inputs()[0].name: prep_img})
        return self.postprocess(img, prep_img, outs)

    def letterbox(self, img, new_shape=(640, 640)):
        """Resize and pad image while maintaining aspect ratio."""
        shape = img.shape[:2]  # current shape [height, width]

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])

        # Compute padding
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2  # wh padding

        if shape[::-1] != new_unpad:  # resize
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
        return img

    def preprocess(self, img, new_shape):
        """Preprocess the input image before feeding it into the model."""
        img = self.letterbox(img, new_shape)
        img = img[..., ::-1].transpose([2, 0, 1])[None]  # BGR to RGB, BHWC to BCHW
        img = np.ascontiguousarray(img)
        img = img.astype(np.float32) / 255  # Normalize to [0, 1]
        return img

    def postprocess(self, img, prep_img, outs):
        """Post-process model predictions to extract meaningful results."""
        preds, protos = [torch.from_numpy(p) for p in outs]
        preds = ops.non_max_suppression(preds, self.conf, self.iou, nc=len(self.classes))
        results = []
        for i, pred in enumerate(preds):
            pred[:, :4] = ops.scale_boxes(prep_img.shape[2:], pred[:, :4], img.shape)
            masks = self.process_mask(protos[i], pred[:, 6:], pred[:, :4], img.shape[:2])
            results.append(Results(img, path="", names=self.classes, boxes=pred[:, :6], masks=masks))
        return results

    def process_mask(self, protos, masks_in, bboxes, shape):
        """Process prototype masks with predicted mask coefficients to generate instance segmentation masks."""
        c, mh, mw = protos.shape  # CHW
        masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # Matrix multiplication
        masks = ops.scale_masks(masks[None], shape)[0]  # Scale masks to original image size
        masks = ops.crop_mask(masks, bboxes)  # Crop masks to bounding boxes
        return masks.gt_(0.0)  # Convert to binary masks
    def visualize_segmentation(self, image, results, alpha=0.95):
        """Overlay the predicted instance masks on the original image."""
        # Work on a copy of the input image
        visualization = image.copy()

        # Unwrap the prediction results
        if isinstance(results, list):
            result = results[0]  # only use the first result
        else:
            result = results

        # Check whether segmentation masks are present
        if hasattr(result, "masks") and result.masks is not None:
            # Get bounding boxes, classes and confidences
            boxes = result.boxes.cpu().numpy()
            masks = result.masks.data.cpu().numpy()

            # Generate one random colour per instance
            num_instances = len(boxes)
            colors = np.random.randint(0, 255, size=(num_instances, 3), dtype=np.uint8)

            # Iterate over every instance
            for i in range(num_instances):
                confidence = boxes.conf[i]
                class_id = int(boxes.cls[i])

                # Get the mask and resize it to match the original image
                mask = masks[i]
                if mask.shape[:2] != image.shape[:2]:
                    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

                # Build a coloured mask
                color_mask = np.zeros_like(image)
                mask_bool = (mask > 0).astype(bool)
                color = colors[i].tolist()
                color_mask[mask_bool] = color

                # Blend the coloured mask with the current visualization
                visualization = cv2.addWeighted(visualization, 1.0, color_mask, alpha, 0)

                # # Draw the bounding box
                # x1, y1, x2, y2 = map(int, boxes.xyxy[i])
                # cv2.rectangle(visualization, (x1, y1), (x2, y2), color, 2)
                # # Look up the class name
                # class_name = self.classes[class_id]
                # # Draw class name and confidence
                # label = f"{class_name}: {confidence:.2f}"
                # text_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
                # cv2.rectangle(visualization, (x1, y1 - text_size[1] - 5), (x1 + text_size[0], y1), color, -1)
                # cv2.putText(visualization, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        return visualization
```
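Before wiring this class into the GUI, it is worth smoke-testing it on its own. Below is a minimal sketch; the weight path `weights/best.onnx` and the class list `["goldLine"]` mirror the demo configuration used later, and the test image name is a placeholder:

```python
import cv2

from yoloSeg.model import YOLOv12Seg

# Weight path and class names must match your own export (these mirror the demo below)
seg = YOLOv12Seg("weights/best.onnx", ["goldLine"], conf=0.25, iou=0.7)

image = cv2.imread("test.jpg")                        # BGR image, as preprocess() expects
results = seg(image)                                  # list of Results objects
overlay = seg.visualize_segmentation(image, results)  # blend masks onto the original image
cv2.imwrite("overlay.jpg", overlay)                   # save the visualization
```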
With the algorithm module made self-contained, the next step is the PyQt-based GUI module. The reason I stress keeping the two independent is to maximize reuse of the GUI code: even if the algorithm is swapped out later, only a few lines need to change to build a second demo application quickly. Since this is a vision project, the software mainly needs to present the original image, the result image, and a few related function buttons. Below is a software template with comments that you can use directly.

```python
import sys
import os

import cv2
import numpy as np
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtGui import QImage, QPixmap
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
                             QPushButton, QLabel, QFileDialog, QTabWidget, QScrollArea,
                             QSplitter, QMessageBox)

from yoloSeg.model import YOLOv12Seg as SegAI


class ImageProcessor(QThread):
    """Worker thread that runs the detection so the UI does not block."""
    detection_completed = pyqtSignal(np.ndarray)

    def __init__(self):
        super().__init__()
        self.image = None
        self.ai = SegAI("weights/best.onnx", ["goldLine"], conf=0.2, iou=0.8)

    def set_images(self, image):
        self.image = image

    def run(self):
        result_image = self.detection_algorithm()
        self.detection_completed.emit(result_image)

    def detection_algorithm(self):
        results = self.ai(self.image)
        visualization = self.ai.visualize_segmentation(self.image, results)
        return visualization


class ImageViewer(QLabel):
    """Image viewer widget."""

    def __init__(self):
        super().__init__()
        self.setAlignment(Qt.AlignCenter)
        self.setMinimumSize(300, 300)
        self.setStyleSheet("border: 1px solid gray; background-color: #f0f0f0;")
        self.setScaledContents(False)
        self.current_pixmap = None

    def setImage(self, image):
        if isinstance(image, np.ndarray):
            # Convert the OpenCV (BGR) image to a QPixmap
            height, width, channel = image.shape
            bytesPerLine = 3 * width
            qImg = QImage(image.data, width, height, bytesPerLine, QImage.Format_RGB888).rgbSwapped()
            pixmap = QPixmap.fromImage(qImg)
        else:
            pixmap = QPixmap(image)

        if pixmap.isNull():
            return
        self.current_pixmap = pixmap
        self.updatePixmap()

    def updatePixmap(self):
        if self.current_pixmap:
            # Scale to the label size while keeping the aspect ratio
            scaled_pixmap = self.current_pixmap.scaled(
                self.width(), self.height(),
                Qt.KeepAspectRatio, Qt.SmoothTransformation
            )
            super().setPixmap(scaled_pixmap)

    def resizeEvent(self, event):
        self.updatePixmap()
        super().resizeEvent(event)


class ImageProcessingApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.image = None
        self.detection_result = None
        # Widgets
        self.initUI()
        self.image_processor = ImageProcessor()
        self.image_processor.detection_completed.connect(self.on_detection_completed)

    def initUI(self):
        # Window title and size
        self.setWindowTitle('Detection Demo')
        self.setGeometry(100, 100, 1200, 700)

        # Central widget and overall layout
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)

        # Top button row
        button_layout = QHBoxLayout()

        # Create the buttons
        self.btn_open = QPushButton('Open Image')
        self.btn_detection = QPushButton('Run Detection')
        self.btn_save = QPushButton('Save Image')

        # Add the buttons to the layout
        button_layout.addWidget(self.btn_open)
        button_layout.addWidget(self.btn_detection)
        button_layout.addWidget(self.btn_save)

        # Connect the button signals
        self.btn_open.clicked.connect(self.open_file)
        self.btn_detection.clicked.connect(self.detect_image)
        self.btn_save.clicked.connect(self.save_image)

        # Disable saving until a detection result exists
        self.btn_save.setEnabled(False)

        # Splitter dividing the left and right areas
        splitter = QSplitter(Qt.Horizontal)

        # Left: input image area
        self.image_tabs = QTabWidget()
        self.image_viewer = ImageViewer()
        self.image_tabs.addTab(self.image_viewer, "Input Image")
        splitter.addWidget(self.image_tabs)

        # Right: result area (tabbed)
        self.result_tabs = QTabWidget()
        self.detection_tab = ImageViewer()
        self.result_tabs.addTab(self.detection_tab, "Detection Result")
        splitter.addWidget(self.result_tabs)

        # Splitter proportions
        splitter.setSizes([600, 600])

        # Assemble the main layout
        main_layout.addLayout(button_layout)
        main_layout.addWidget(splitter, 1)

        # Status bar
        self.statusBar().showMessage('Ready')

    def open_file(self):
        """Open an image file and load it."""
        file_path, filetype = QFileDialog.getOpenFileName(
            self,
            "Select file",
            os.getcwd(),  # start directory
            "Image Files (*.jpg *.jpeg *.png *.bmp *.tif *.tiff)")
        if file_path:
            # np.fromfile + imdecode also handles non-ASCII paths on Windows
            self.image = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), cv2.IMREAD_COLOR)
            if self.image is not None:
                self.image_viewer.setImage(self.image)
            else:
                QMessageBox.warning(self, 'Warning', 'Failed to load the image file!')

    def detect_image(self):
        """Run detection on the loaded image."""
        if self.image is None:
            QMessageBox.warning(self, 'Warning', 'Please open an image first!')
            return
        self.statusBar().showMessage('Running detection...')
        # Run the detection in the worker thread
        self.image_processor.set_images(self.image)
        self.image_processor.start()

    def on_detection_completed(self, result_image):
        """Callback fired when detection finishes."""
        self.detection_tab.setImage(result_image)
        self.result_tabs.setCurrentWidget(self.detection_tab)  # switch to the detection tab
        # Keep the detection result for saving
        self.detection_result = result_image
        self.statusBar().showMessage('Detection finished')
        self.btn_save.setEnabled(True)  # enable the save button

    def save_image(self):
        """Save the latest result image."""
        cv2.imwrite("result.jpg", self.detection_result)
        self.statusBar().showMessage('Image saved')


if __name__ == "__main__":
    app = QApplication(sys.argv)
    ex = ImageProcessingApp()
    ex.show()
    sys.exit(app.exec_())
```
Note that the ImageProcessor class is essentially the algorithm class's proxy inside the GUI, acting like a protocol; from now on you only need to modify this class and the algorithm class as your requirements change.
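As a purely hypothetical illustration of how small that change is, swapping in a different algorithm essentially means replacing the import, the constructor call, and the drawing call inside ImageProcessor; `otherAlgo.model.OtherAI`, its weight path, and its class list below are made-up names, not a real module:

```python
import numpy as np
from PyQt5.QtCore import QThread, pyqtSignal

# Hypothetical algorithm class with the same "image in, results out" contract as YOLOv12Seg
from otherAlgo.model import OtherAI as SegAI


class ImageProcessor(QThread):
    """Proxy thread: only the import above and the two marked lines change."""
    detection_completed = pyqtSignal(np.ndarray)

    def __init__(self):
        super().__init__()
        self.image = None
        self.ai = SegAI("weights/other.onnx", ["classA", "classB"])  # changed: new weights and classes

    def set_images(self, image):
        self.image = image

    def run(self):
        results = self.ai(self.image)
        # changed: call whatever drawing method the new algorithm class exposes
        self.detection_completed.emit(self.ai.visualize_segmentation(self.image, results))
```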