Skip to content

Commit a2d1871

Browse files
authored
Merge pull request #207 from luxonis/feat/output-shapes-and-layout
Get shape for each output from the ONNX and infer layout + YOLO26 changes + output field re-names
2 parents d314ab1 + cddb0f6 commit a2d1871

11 files changed

Lines changed: 319 additions & 35 deletions

File tree

tests/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
{"name": "yolo26x", "version": "v26"},
7777
{"name": "yolo26n-seg", "version": "v26"},
7878
{"name": "yolo26n-pose", "version": "v26"},
79+
{"name": "yolo26n", "version": "v26_nms", "cli_version": "yolov26_nms"},
7980
{"name": "yolov8n-cls", "version": "v8"},
8081
{"name": "yolov8n-seg", "version": "v8"},
8182
{"name": "yolov8n-pose", "version": "v8"},

tests/helper_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,33 @@ def nn_archive_checker(extra_keys_to_check: list = []): # noqa: B006
127127
assert temp_cfg[keys[-1]] == target, (
128128
f"Value `{temp_cfg[keys[-1]]}` at key `{keys}` doesn't match expected value `{target}`"
129129
)
130+
131+
132+
def load_latest_nn_archive_config() -> dict:
133+
"""Load config.json from the most recently exported NNArchive."""
134+
output_dir = "shared_with_container/outputs"
135+
subdirs = [
136+
d for d in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, d))
137+
]
138+
assert subdirs, f"No folders found in `{output_dir}`"
139+
140+
subdirs.sort(key=lambda d: os.path.getmtime(os.path.join(output_dir, d)))
141+
latest_subdir = subdirs[-1]
142+
model_output_path = os.path.join(output_dir, latest_subdir)
143+
144+
archive_files = [f for f in os.listdir(model_output_path) if f.endswith(".tar.xz")]
145+
assert len(archive_files) == 1, (
146+
f"Expected 1 .tar.xz file, found {len(archive_files)}: {archive_files}"
147+
)
148+
archive_path = os.path.join(model_output_path, archive_files[0])
149+
150+
with tarfile.open(archive_path, "r:xz") as tar:
151+
file_names = [m.name for m in tar.getmembers() if m.isfile()]
152+
config_files = [name for name in file_names if name.endswith("config.json")]
153+
assert len(config_files) == 1, (
154+
f"Expected 1 config.json file, found {len(config_files)}: {config_files}"
155+
)
156+
config_member = tar.getmember(config_files[0])
157+
config_file = tar.extractfile(config_member)
158+
assert config_file is not None, "Failed to extract config.json"
159+
return json.load(config_file)

tests/nnarchive_output_checks.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from __future__ import annotations
2+
3+
from copy import deepcopy
4+
5+
V8_DETECTION_CHECK = {
6+
"name": "yolov8n",
7+
"version": "v8",
8+
"model_outputs": ["output1_yolov6r2", "output2_yolov6r2", "output3_yolov6r2"],
9+
"head_outputs": ["output1_yolov6r2", "output2_yolov6r2", "output3_yolov6r2"],
10+
"yolo_outputs": ["output1_yolov6r2", "output2_yolov6r2", "output3_yolov6r2"],
11+
}
12+
13+
V8_SEG_CHECK = {
14+
"name": "yolov8n-seg",
15+
"version": "v8",
16+
"model_outputs": [
17+
"output1_yolov8",
18+
"output2_yolov8",
19+
"output3_yolov8",
20+
"output1_masks",
21+
"output2_masks",
22+
"output3_masks",
23+
"protos_output",
24+
],
25+
"head_outputs": [
26+
"output1_yolov8",
27+
"output2_yolov8",
28+
"output3_yolov8",
29+
"output1_masks",
30+
"output2_masks",
31+
"output3_masks",
32+
"protos_output",
33+
],
34+
"yolo_outputs": ["output1_yolov8", "output2_yolov8", "output3_yolov8"],
35+
"mask_outputs": ["output1_masks", "output2_masks", "output3_masks"],
36+
}
37+
38+
V8_POSE_CHECK = {
39+
"name": "yolov8n-pose",
40+
"version": "v8",
41+
"model_outputs": [
42+
"output1_yolov8",
43+
"output2_yolov8",
44+
"output3_yolov8",
45+
"kpt_output1",
46+
"kpt_output2",
47+
"kpt_output3",
48+
],
49+
"head_outputs": [
50+
"output1_yolov8",
51+
"output2_yolov8",
52+
"output3_yolov8",
53+
"kpt_output1",
54+
"kpt_output2",
55+
"kpt_output3",
56+
],
57+
"yolo_outputs": ["output1_yolov8", "output2_yolov8", "output3_yolov8"],
58+
"keypoints_outputs": ["kpt_output1", "kpt_output2", "kpt_output3"],
59+
}
60+
61+
62+
def _clone_check(base_case: dict, *, name: str, version: str) -> dict:
63+
case = deepcopy(base_case)
64+
case["name"] = name
65+
case["version"] = version
66+
return case
67+
68+
69+
N_VARIANT_OUTPUT_NAME_CHECKS = [
70+
V8_DETECTION_CHECK,
71+
V8_SEG_CHECK,
72+
V8_POSE_CHECK,
73+
_clone_check(V8_DETECTION_CHECK, name="yolov9t", version="v9"),
74+
_clone_check(V8_DETECTION_CHECK, name="yolov11n", version="v11"),
75+
_clone_check(V8_SEG_CHECK, name="yolov11n-seg", version="v11"),
76+
_clone_check(V8_POSE_CHECK, name="yolov11n-pose", version="v11"),
77+
_clone_check(V8_DETECTION_CHECK, name="yolov12n", version="v12"),
78+
{
79+
"name": "yolo26n",
80+
"version": "v26",
81+
"model_outputs": ["output_yolo26"],
82+
"head_outputs": ["output_yolo26"],
83+
"yolo_outputs": ["output_yolo26"],
84+
},
85+
{
86+
"name": "yolo26n-seg",
87+
"version": "v26",
88+
"model_outputs": ["output_yolo26", "output_masks", "protos_output"],
89+
"head_outputs": ["output_yolo26", "output_masks", "protos_output"],
90+
"yolo_outputs": ["output_yolo26"],
91+
"mask_outputs": ["output_masks"],
92+
},
93+
{
94+
"name": "yolo26n-pose",
95+
"version": "v26",
96+
"model_outputs": ["output_yolo26", "kpt_output"],
97+
"head_outputs": ["output_yolo26", "kpt_output"],
98+
"yolo_outputs": ["output_yolo26"],
99+
"keypoints_outputs": ["kpt_output"],
100+
},
101+
]

tests/test_end2end.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
import pytest
88
from constants import PRIVATE_TEST_MODELS, SAVE_FOLDER, TEST_MODELS
9-
from helper_functions import download_model, download_private_model, nn_archive_checker
9+
from helper_functions import (
10+
download_model,
11+
download_private_model,
12+
load_latest_nn_archive_config,
13+
nn_archive_checker,
14+
)
15+
from nnarchive_output_checks import N_VARIANT_OUTPUT_NAME_CHECKS
1016

1117
logger = logging.getLogger()
1218
logger.setLevel(logging.INFO)
@@ -15,7 +21,12 @@
1521
@pytest.mark.parametrize(
1622
"model",
1723
TEST_MODELS,
18-
ids=[model["name"] for model in TEST_MODELS],
24+
ids=[
25+
model.get("cli_version", model["name"])
26+
if model.get("cli_version")
27+
else model["name"]
28+
for model in TEST_MODELS
29+
],
1930
)
2031
def test_cli_conversion(model: dict, test_config: dict, subtests):
2132
"""Tests the whole CLI conversion flow with no extra params specified."""
@@ -50,6 +61,8 @@ def test_cli_conversion(model: dict, test_config: dict, subtests):
5061
pytest.skip("Weights not present and `download_weights` not set")
5162

5263
command = ["tools", model_path]
64+
if model.get("cli_version"):
65+
command += ["--version", model.get("cli_version")]
5366
if model.get("size"): # edge case when stride=64 is needed
5467
command += ["--imgsz", model.get("size")]
5568

@@ -79,6 +92,65 @@ def test_cli_conversion(model: dict, test_config: dict, subtests):
7992
nn_archive_checker(extra_keys_to_check=extra_keys_to_check)
8093

8194

95+
@pytest.mark.parametrize(
96+
"model_case",
97+
N_VARIANT_OUTPUT_NAME_CHECKS,
98+
ids=[model_case["name"] for model_case in N_VARIANT_OUTPUT_NAME_CHECKS],
99+
)
100+
def test_n_variant_nnarchive_outputs(model_case: dict, test_config: dict):
101+
"""Checks NNArchive output-related fields for selected variants."""
102+
if (
103+
test_config["test_case"] is not None
104+
and model_case["name"] != test_config["test_case"]
105+
):
106+
pytest.skip(
107+
f"Test case ({model_case['name']}) doesn't match selected test case ({test_config['test_case']})"
108+
)
109+
110+
if (
111+
test_config["yolo_version"] is not None
112+
and model_case["version"] != test_config["yolo_version"]
113+
):
114+
pytest.skip(
115+
f"Model version ({model_case['version']}) doesn't match selected version ({test_config['yolo_version']})."
116+
)
117+
118+
model_path = os.path.join(SAVE_FOLDER, f"{model_case['name']}.pt")
119+
if not os.path.exists(model_path):
120+
if test_config["download_weights"]:
121+
model_path = download_model(model_case["name"], SAVE_FOLDER)
122+
else:
123+
pytest.skip("Weights missing and `download_weights` not set")
124+
125+
command = ["tools", model_path]
126+
result = subprocess.run(
127+
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
128+
)
129+
if result.returncode != 0:
130+
pytest.fail(f"Exit code: {result.returncode}, Output: {result.stdout}")
131+
132+
cfg = load_latest_nn_archive_config()
133+
output_names = [output["name"] for output in cfg["model"]["outputs"]]
134+
head = cfg["model"]["heads"][0]
135+
metadata = head["metadata"]
136+
head_output_names = head["outputs"]
137+
yolo_output_names = metadata["yolo_outputs"] or []
138+
mask_output_names = metadata["mask_outputs"] or []
139+
keypoint_output_names = metadata["keypoints_outputs"] or []
140+
141+
for key, actual in [
142+
("model_outputs", output_names),
143+
("head_outputs", head_output_names),
144+
("yolo_outputs", yolo_output_names),
145+
("mask_outputs", mask_output_names),
146+
("keypoints_outputs", keypoint_output_names),
147+
]:
148+
for expected_name in model_case.get(key, []):
149+
assert expected_name in actual, (
150+
f"{key}: expected `{expected_name}` for {model_case['name']}, got {actual}"
151+
)
152+
153+
82154
@pytest.mark.parametrize(
83155
"model",
84156
PRIVATE_TEST_MODELS,

tools/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
YOLOV11_CONVERSION,
2929
YOLOV12_CONVERSION,
3030
YOLOV26_CONVERSION,
31+
YOLOV26_NMS_CONVERSION,
3132
detect_version,
3233
)
3334

@@ -50,6 +51,7 @@
5051
YOLOV11_CONVERSION,
5152
YOLOV12_CONVERSION,
5253
YOLOV26_CONVERSION,
54+
YOLOV26_NMS_CONVERSION,
5355
]
5456

5557

@@ -176,6 +178,7 @@ def convert(
176178
YOLOV9_CONVERSION,
177179
YOLOV11_CONVERSION,
178180
YOLOV12_CONVERSION,
181+
YOLOV26_NMS_CONVERSION,
179182
]:
180183
from tools.yolo.yolov8_exporter import YoloV8Exporter
181184

tools/modules/exporter.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
from datetime import datetime
5-
from typing import List, Optional, Tuple
5+
from typing import Any, Dict, List, Optional, Tuple
66

77
import onnx
88
import onnxsim
@@ -101,6 +101,44 @@ def export_onnx(self) -> os.PathLike:
101101

102102
return self.f_onnx
103103

104+
@staticmethod
105+
def _infer_layout_from_shape(shape: List[Any]) -> Optional[str]:
106+
rank = len(shape)
107+
if rank == 4:
108+
return "NCHW"
109+
if rank == 3:
110+
return "NCD"
111+
if rank == 2:
112+
return "NC"
113+
if rank == 1:
114+
return "C"
115+
return None
116+
117+
def get_output_specs(self) -> Dict[str, Dict[str, Any]]:
118+
"""Collect output shape and layout for all ONNX outputs by name."""
119+
if self.f_onnx is None:
120+
raise RuntimeError("ONNX must be exported before reading output specs.")
121+
122+
model_onnx = onnx.load(self.f_onnx)
123+
specs: Dict[str, Dict[str, Any]] = {}
124+
125+
for output in model_onnx.graph.output:
126+
shape: List[Any] = []
127+
for dim in output.type.tensor_type.shape.dim:
128+
if dim.HasField("dim_value"):
129+
shape.append(int(dim.dim_value))
130+
elif dim.HasField("dim_param") and dim.dim_param:
131+
shape.append(dim.dim_param)
132+
else:
133+
shape.append(None)
134+
135+
specs[output.name] = {
136+
"shape": shape,
137+
"layout": self._infer_layout_from_shape(shape),
138+
}
139+
140+
return specs
141+
104142
def make_nn_archive(
105143
self,
106144
class_list: List[str],
@@ -144,6 +182,7 @@ def make_nn_archive(
144182

145183
if output_kwargs is None:
146184
output_kwargs = {}
185+
output_specs = self.get_output_specs()
147186

148187
archive = ArchiveGenerator(
149188
archive_name=self.model_name,
@@ -172,6 +211,8 @@ def make_nn_archive(
172211
{
173212
"name": output,
174213
"dtype": DataType.FLOAT32,
214+
"shape": output_specs.get(output, {}).get("shape"),
215+
"layout": output_specs.get(output, {}).get("layout"),
175216
}
176217
for output in self.all_output_names
177218
],

0 commit comments

Comments
 (0)