绘制yolo-pose 的关键点-->标签可视化到图像

1 关键点参数说明

对yolo-pose 的原理不做介绍,主要针对其标签数据做可视化操作

用于训练YOLO姿势模型的数据集标签格式如下:

  1. 每个图像对应一个文本文件:数据集中的每个图像都有一个对应的文本文件,该文本文件与图像文件同名,扩展名为 “.txt”。
  2. 每行一个对象:文本文件中的每一行对应于图像中的一个对象实例。
  3. 每行包含的对象信息:每行包含关于对象实例的以下信息:
    1. 对象类别索引:一个整数,表示对象的类别(例如,0 代表人,1 代表汽车等)。
    2. 对象中心坐标:对象的中心 x 和 y 坐标,已归一化到 0 和 1 之间。
    3. 对象宽度和高度:对象的宽度和高度,已归一化到 0 和 1 之间。
    4. 对象关键点坐标:对象的关键点坐标,已归一化到 0 和 1 之间。

下面是 “姿势估计 “任务的标签格式示例:

二维关键点格式:

1
<class-index> <x> <y> <width> <height> <px1> <py1> <px2> <py2> ... <pxn> <pyn>

带 3D 关键点的格式(包括每个点的可见度)

1
<class-index> <x> <y> <width> <height> <px1> <py1> <p1-visibility> <px2> <py2> <p2-visibility> <pxn> <pyn> <pn-visibility>

在这种格式中, <class-index> 是对象的类别索引, <x> <y> <width> <height>边界框的坐标<px1> <py1> <px2> <py2> ... <pxn> <pyn> 是归一化的关键点坐标。可见度通道是可选的,但对注释闭塞的数据集很有用。

带 3D 关键点的格式如图所示

9f86271c-4fef-48bb-905b-d52238f01910

2 辅助信息讲解

针对yolo关键点的数据存储结构,我们需要明确以下辅助信息。

2.1 对象类别名称

即框的目标是什么?

位置:标注信息的第一位

意义:记录对象的索引值

若有多个类别,则设定类别列表,供后续使用

1
2
3
4
# 类别名称
class_names = [
'male', 'female'
]

注:两个类别的关键点数量应保持一致

2.2 关键点名称

每个点所指的是什么区域

按照自然顺序记录,标注内容需手动声明

位置:标注信息中去除前4位后,剩余内容皆为关键点数据

以YOLOv8-pose人体姿态估计为例,在COCO数据集上身体的每一个关节具有一个序号,共17个点:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
COCO_keypoint_indexes = {
0: 'nose',
1: 'left_eye',
2: 'right_eye',
3: 'left_ear',
4: 'right_ear',
5: 'left_shoulder',
6: 'right_shoulder',
7: 'left_elbow',
8: 'right_elbow',
9: 'left_wrist',
10: 'right_wrist',
11: 'left_hip',
12: 'right_hip',
13: 'left_knee',
14: 'right_knee',
15: 'left_ankle',
16: 'right_ankle'
}

2.3 骨架链接关系

有了关键点坐标信息,我们可以将它们进行连接展示

最终的结果是每个点至少有一条连线

数据的格式则为两两一组

即:当connections=((9, 7), (7, 5), (5, 6), (6, 8), (8, 10)),绘制了手臂。

当connections=((2, 4), (1, 3), (10, 8), (8, 6), (6, 5), (5, 7), (7, 9), (6, 12), (12, 14), (14, 16), (5, 11), (11, 13), (13, 15)),绘制了身体骨架。

3 代码思路构建

有了参数说明,可以针对性的编写可视化代码

首先要明确使用的绘制工具,当前计划使用opencv的库完成最终的绘制操作。

考虑到训练结束后需要将结果展示,则代码就要有一定的可兼容空间。

那么,图像数据需要转为tensor格式进行操作,以便后续调用修改。

下面整合程序流程:

f0a9769a-fce9-4416-af51-5c107b7e5e49

4 代码详解

以下代码按顺序存放在同一个文件中,命名为:view_yolo_keypoint.py

4.1 导入第三方库

1
2
3
import cv2
import numpy as np
import os

4.2 自定义配置

  • 图片和标签路径:可使用绝对路径和相对路径
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# ============================ 用户配置 ============================
# 图片和标签文件路径
image_path = r'images\test\feed_bare-13-_jpg.rf.970a0b8dabcb055f13c510925a60792b.jpg'
label_path = r'labels\test\feed_bare-13-_jpg.rf.970a0b8dabcb055f13c510925a60792b.txt'

# 骨架连接关系 (关键点索引从0开始)
# 例如:(0, 1) 表示关键点0和关键点1之间有一条线
skeleton = ((2, 4), (1, 3), (10, 8), (8, 6), (6, 5), (5, 7), (7, 9), (6, 12), (12, 14), (14, 16), (5, 11), (11, 13), (13, 15))

# 关键点名称 (按顺序对应标签文件中的关键点)
keypoint_names = [
'nose', # 0
'left_eye', # 1
'right_eye', # 2
'left_ear', # 3
'right_ear', # 4
'left_shoulder', # 5
'right_shoulder', # 6
'left_elbow', # 7
'right_elbow', # 8
'left_wrist', # 9
'right_wrist', # 10
'left_hip', # 11
'right_hip', # 12
'left_knee', # 13
'right_knee', # 14
'left_ankle', # 15
'right_ankle' # 16
]

# 类别名称
class_names = [
'male', 'female'
]
# ==================================================================

4.3 数据读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class ReadimgLabel:  # 方便管理零散函数
def __init__(self):
pass

@staticmethod
def get_img(image_path):
# 读取图片
print('image_path: ', image_path)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError(f"图片文件 {image_path} 不存在")
# 获取图片尺寸
height, width, _ = image.shape
print(f"图片尺寸: {width}x{height}")
return image

@staticmethod
def get_labels(label_path):
# 读取标签
if not os.path.exists(label_path):
raise FileNotFoundError(f"标签文件 {label_path} 不存在")
with open(label_path, 'r') as f:
lines = f.readlines() # 一次性读取文件中的 所有行,并将它们作为一个列表返回
return lines

@staticmethod
def format_labels(lines):
xywh = [] # 对象的中心坐标和宽高, 形状为 [x, y, w, h]
keypoint = [] # 关键点坐标, 形状为 [(px, py, visibility)]
objects = [] # 一个对象的标签, 形状为 [class_index,xywh, keypoint]
for line in lines:
line = line.strip().split(' ')
xywh = [float(line[1]), float(line[2]), float(line[3]), float(line[4])]
for i in range(5, len(line), 3):
keypoint.append([float(line[i]), float(line[i+1]), float(line[i+2])])
objects.append([line[0], xywh, keypoint])
return objects

4.4 数据整合方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class KeyPointDrawer:
"""
可视化关键点和骨架
传入参数:
image: 图片的tensor数据
labels: 标签数据, 形状为 [[class_index,xywh, keypoint]]
"""
def __init__(self, image, labels):
self.image = image
self.img_w, self.img_h = image.shape[1], image.shape[0]
self.labels = labels

def draw_one_object(self,class_names, one_object, keypoint_names, skeleton):
"""
绘制一个对象的关键点和骨架
传入参数:
one_object: 一个对象的标签, 形状为 [class_index,xywh, keypoint]
keypoint_names: 关键点名称, 形状为 [keypoint_name]
skeleton: 骨架连接关系, 形状为 [(keypoint_index1, keypoint_index2)]
"""
class_index, xywh, keypoint = one_object # 类别索引

# 设定显示颜色
np.random.seed(888) # 固定随机数种子 (修改种子以刷新颜色)
# 生成足够的颜色: 关键点数量 + 骨架线条数量 + 额外的一些用于对象
num_colors = len(keypoint_names) + len(skeleton) + 50
all_colors = np.random.randint(0, 255, size=(num_colors, 3)).tolist()

# 分离颜色: 关键点颜色 -> 骨架颜色 -> 对象颜色
kp_colors = all_colors[:len(keypoint_names)]
sk_colors = all_colors[len(keypoint_names):len(keypoint_names)+len(skeleton)]
obj_colors = all_colors[len(keypoint_names)+len(skeleton):]

# 根据对象在classname中的索引生成一个唯一的索引
obj_color_idx = int(class_index) % len(obj_colors)
box_color = [int(c) for c in obj_colors[obj_color_idx]]

# 绘制目标框
x, y, w, h = xywh
x1 = int((x - w / 2) * self.img_w)
y1 = int((y - h / 2) * self.img_h)
x2 = int((x + w / 2) * self.img_w)
y2 = int((y + h / 2) * self.img_h)
cv2.rectangle(self.image, (x1, y1), (x2, y2), box_color, 2)
# 显示目标类别
text_y = y1 - 5 if y1 - 5 > 10 else y1 + 15
cv2.putText(self.image, class_names[int(class_index)], (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color, 1)

# 绘制骨架
for i, (idx1, idx2) in enumerate(skeleton):
# 获取两个关键点的坐标和可见性
# 注意: 某些数据集可能 keypoint 长度不一致或者索引越界,这里最好做个检查,或者假设输入合法
if idx1 < len(keypoint) and idx2 < len(keypoint):
x1, y1, v1 = keypoint[idx1]
x2, y2, v2 = keypoint[idx2]

# 如果关键点不可见(v=0), 则不绘制骨架(可选策略)
if v1 > 0 and v2 > 0:
pos1 = (int(x1 * self.img_w), int(y1 * self.img_h))
pos2 = (int(x2 * self.img_w), int(y2 * self.img_h))
# 使用随机生成的颜色,确保每条线颜色不同
line_color = [int(c) for c in sk_colors[i]]
cv2.line(self.image, pos1, pos2, line_color, 2)

# 绘制关键点
for i, (px, py, visibility) in enumerate(keypoint):
color = [int(c) for c in kp_colors[i % len(kp_colors)]] # 设置关键点颜色
# 设置可见点为圈, 不可见点为×
if visibility == 2:
cv2.circle(self.image, (int(px * self.img_w), int(py * self.img_h)), 3, color, -1)
else:
cv2.putText(self.image, 'x', (int(px * self.img_w), int(py * self.img_h)),
cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
# 在关键点位置显示关键点名称
# 在关键点位置显示关键点名称
kpt_name = keypoint_names[i] if i < len(keypoint_names) else str(i)
cv2.putText(self.image, kpt_name, (int(px * self.img_w), int(py * self.img_h)),
cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)

def draw_all_objects(self, class_names, keypoint_names, skeleton):
"""
绘制所有对象的关键点和骨架
传入参数:
keypoint_names: 关键点名称, 形状为 [keypoint_name]
skeleton: 骨架连接关系, 形状为 [(keypoint_index1, keypoint_index2)]
"""
for obj in self.labels:
self.draw_one_object(class_names, obj, keypoint_names, skeleton)
return self.image

4.5 图像展示

1
2
3
4
5
6
7
8
class ShowImage:
def __init__(self, image):
self.image = image

def show(self):
cv2.imshow('image', self.image)
cv2.waitKey(0)
cv2.destroyAllWindows()

4.6 主函数

1
2
3
4
5
6
7
if __name__ == '__main__':
image = ReadimgLabel.get_img(image_path) # 读取图片
labels = ReadimgLabel.get_labels(label_path) # 读取标签
objects = ReadimgLabel.format_labels(labels) # 格式化标签
drawer = KeyPointDrawer(image, objects).draw_all_objects(class_names, keypoint_names, skeleton)
ShowImage(drawer).show()