使用 YOLOv8 进行实例分割
源代码
本文的源代码参考了 Pysource,如果需要下载请到此博客处下载。
1. 安装 YOLOv8 环境
确保使用可用的 CUDA 环境以保证速度。
pip install ultralytics
pip install opencv-python
ultralytics 官方已经封装了各种 YOLOv8 的模型,首次使用时会自动下载权重,我们不需要手动下载就可以使用。
2. YOLOv8 实例分割
我们先创建 yolo_segmentation.py:
import numpy as np
from ultralytics import YOLO
from ultralytics.yolo.engine.results import Results
class YOLOSegmentation:
    """Thin wrapper that runs a YOLO segmentation model on a single image."""

    def __init__(self, model_path: str):
        """Load the model.

        Args:
            model_path: Weights path or model name handed to ``YOLO``.
        """
        self.model = YOLO(model_path)

    def detect(self, img: np.ndarray):
        """Run the model on ``img`` and return its detections.

        Args:
            img: Image array whose first two axes are (height, width).
                A copy is passed to the model, so ``img`` is not modified.

        Returns:
            A 4-tuple ``(bboxes, class_ids, segmentation_contours_idx, scores)``:

            * ``bboxes``: integer array built from ``result.boxes.xyxy``.
            * ``class_ids``: integer array built from ``result.boxes.cls``.
            * ``segmentation_contours_idx``: list of ``int32`` point arrays,
              one per entry of ``result.masks.xyn``, scaled to pixel units.
            * ``scores``: float array built from ``result.boxes.conf``,
              rounded to two decimals.

            Four empty lists are returned when the result carries no boxes
            or no masks.
        """
        # Only the first two axes are needed; this also accepts 2-D images.
        height, width = img.shape[:2]

        results = self.model.predict(source=img.copy(), save=False, save_txt=False)
        result: Results = results[0]

        if result.boxes is None or result.masks is None:
            return [], [], [], []

        segmentation_contours_idx = []
        for seg in result.masks.xyn:
            points = np.asarray(seg)
            # Scale into a new array: an in-place ``*=`` would overwrite the
            # coordinate arrays that belong to the result object.
            scale = np.array([width, height], dtype=points.dtype)
            segmentation_contours_idx.append((points * scale).astype(np.int32))

        bboxes = np.array(result.boxes.xyxy.cpu(), dtype="int")
        class_ids = np.array(result.boxes.cls.cpu(), dtype="int")
        scores = np.array(result.boxes.conf.cpu(), dtype="float").round(2)
        return bboxes, class_ids, segmentation_contours_idx, scores
然后创建 main.py:
import cv2
from yolo_segmentation import YOLOSegmentation
# Segment a single image and save it with boxes and contours drawn on top.
IMAGE_PATH = "images/rugby.jpg"

img = cv2.imread(IMAGE_PATH)
# cv2.imread returns None instead of raising when the file cannot be read;
# fail here with a clear message rather than inside cv2.resize.
if img is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
img = cv2.resize(img, None, fx=0.7, fy=0.7)

ys = YOLOSegmentation("yolov8m-seg.pt")
bboxes, classes, segmentations, scores = ys.detect(img)

for bbox, class_id, seg, score in zip(bboxes, classes, segmentations, scores):
    print("bbox:", bbox, "class id:", class_id, "seg:", seg, "score:", score)
    # Hand OpenCV plain Python ints for the corner points.
    x1, y1, x2, y2 = (int(v) for v in bbox)
    cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
    cv2.polylines(img, [seg], True, (0, 0, 255), 4)

cv2.imwrite("res.jpg", img)
下面分别是实验用到的 rugby.jpg 和 basket.jpg 的生成结果:
3. 视频示例
此外,这里还实现了一个视频示例,实时分割手机并标注:
import cv2
from yolo_segmentation import YOLOSegmentation
# Live demo: segment frames from the default camera and highlight detections
# whose class id equals PHONE_CLASS_ID. Press "q" to quit.
PHONE_CLASS_ID = 67

ys = YOLOSegmentation("yolov8m-seg.pt")
cap = cv2.VideoCapture(0)

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        bboxes, classes, segmentations, scores = ys.detect(frame)

        # Draw on a copy so it can be blended over the untouched frame below.
        mask = frame.copy()
        for bbox, class_id, seg, score in zip(bboxes, classes, segmentations, scores):
            if class_id != PHONE_CLASS_ID:
                continue
            # Hand OpenCV plain Python ints for the corner points.
            x1, y1, x2, y2 = (int(v) for v in bbox)
            cv2.rectangle(mask, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.polylines(mask, [seg], True, (0, 0, 255), 4)
            cv2.fillPoly(mask, [seg], (0, 255, 0))
            cv2.putText(
                mask,
                f"{score:.2f}",
                (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.9,
                (36, 255, 12),
                2,
            )

        # 70% original frame + 30% annotated copy gives a translucent overlay.
        frame = cv2.addWeighted(frame, 0.7, mask, 0.3, 0)
        cv2.imshow("frame", frame)
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break
finally:
    # Release the camera and close the window even if the loop raised
    # (for example on Ctrl+C or an inference error).
    cap.release()
    cv2.destroyAllWindows()