|
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | from constants import PRIVATE_TEST_MODELS, SAVE_FOLDER, TEST_MODELS |
9 | | -from helper_functions import download_model, download_private_model, nn_archive_checker |
| 9 | +from helper_functions import ( |
| 10 | + download_model, |
| 11 | + download_private_model, |
| 12 | + load_latest_nn_archive_config, |
| 13 | + nn_archive_checker, |
| 14 | +) |
| 15 | +from nnarchive_output_checks import N_VARIANT_OUTPUT_NAME_CHECKS |
10 | 16 |
|
11 | 17 | logger = logging.getLogger() |
12 | 18 | logger.setLevel(logging.INFO) |
|
15 | 21 | @pytest.mark.parametrize( |
16 | 22 | "model", |
17 | 23 | TEST_MODELS, |
18 | | - ids=[model["name"] for model in TEST_MODELS], |
| 24 | + ids=[ |
| 25 | + model.get("cli_version", model["name"]) |
| 26 | + if model.get("cli_version") |
| 27 | + else model["name"] |
| 28 | + for model in TEST_MODELS |
| 29 | + ], |
19 | 30 | ) |
20 | 31 | def test_cli_conversion(model: dict, test_config: dict, subtests): |
21 | 32 | """Tests the whole CLI conversion flow with no extra params specified.""" |
@@ -50,6 +61,8 @@ def test_cli_conversion(model: dict, test_config: dict, subtests): |
50 | 61 | pytest.skip("Weights not present and `download_weights` not set") |
51 | 62 |
|
52 | 63 | command = ["tools", model_path] |
| 64 | + if model.get("cli_version"): |
| 65 | + command += ["--version", model.get("cli_version")] |
53 | 66 | if model.get("size"): # edge case when stride=64 is needed |
54 | 67 | command += ["--imgsz", model.get("size")] |
55 | 68 |
|
@@ -79,6 +92,65 @@ def test_cli_conversion(model: dict, test_config: dict, subtests): |
79 | 92 | nn_archive_checker(extra_keys_to_check=extra_keys_to_check) |
80 | 93 |
|
81 | 94 |
|
@pytest.mark.parametrize(
    "model_case",
    N_VARIANT_OUTPUT_NAME_CHECKS,
    ids=[model_case["name"] for model_case in N_VARIANT_OUTPUT_NAME_CHECKS],
)
def test_n_variant_nnarchive_outputs(model_case: dict, test_config: dict):
    """Checks NNArchive output-related fields for selected variants.

    Runs the ``tools`` CLI conversion on the variant's weights, loads the
    most recently produced NN archive config, and asserts that every output
    name the test case expects appears in the corresponding archive field
    (model outputs, head outputs, yolo/mask/keypoints output groups).

    Args:
        model_case: One entry from ``N_VARIANT_OUTPUT_NAME_CHECKS``; must
            contain ``name`` and ``version``, and may list expected names
            under ``model_outputs``, ``head_outputs``, ``yolo_outputs``,
            ``mask_outputs``, and ``keypoints_outputs``.
        test_config: Session config fixture providing ``test_case``,
            ``yolo_version``, and ``download_weights`` filters.
    """
    # Honor the CLI/fixture filters: skip cases the user didn't select.
    if (
        test_config["test_case"] is not None
        and model_case["name"] != test_config["test_case"]
    ):
        pytest.skip(
            f"Test case ({model_case['name']}) doesn't match selected test case ({test_config['test_case']})"
        )

    if (
        test_config["yolo_version"] is not None
        and model_case["version"] != test_config["yolo_version"]
    ):
        pytest.skip(
            f"Model version ({model_case['version']}) doesn't match selected version ({test_config['yolo_version']})."
        )

    model_path = os.path.join(SAVE_FOLDER, f"{model_case['name']}.pt")
    if not os.path.exists(model_path):
        if test_config["download_weights"]:
            model_path = download_model(model_case["name"], SAVE_FOLDER)
        else:
            pytest.skip("Weights missing and `download_weights` not set")

    command = ["tools", model_path]
    result = subprocess.run(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )
    if result.returncode != 0:
        pytest.fail(f"Exit code: {result.returncode}, Output: {result.stdout}")

    # Inspect the freshest archive produced by the conversion above.
    cfg = load_latest_nn_archive_config()
    output_names = [output["name"] for output in cfg["model"]["outputs"]]
    # NOTE(review): only the first head is checked — presumably these
    # variants produce a single head; confirm if multi-head variants exist.
    head = cfg["model"]["heads"][0]
    metadata = head["metadata"]

    # `.get(...) or []` tolerates both a missing key and an explicit None,
    # so variants without e.g. mask/keypoint heads don't raise KeyError here
    # (plain indexing with `or []` only covered the None case).
    checks = {
        "model_outputs": output_names,
        "head_outputs": head.get("outputs") or [],
        "yolo_outputs": metadata.get("yolo_outputs") or [],
        "mask_outputs": metadata.get("mask_outputs") or [],
        "keypoints_outputs": metadata.get("keypoints_outputs") or [],
    }

    # Membership (not equality) check: the archive may expose more outputs
    # than the test case enumerates; each expected name must be present.
    for key, actual in checks.items():
        for expected_name in model_case.get(key, []):
            assert expected_name in actual, (
                f"{key}: expected `{expected_name}` for {model_case['name']}, got {actual}"
            )
| 153 | + |
82 | 154 | @pytest.mark.parametrize( |
83 | 155 | "model", |
84 | 156 | PRIVATE_TEST_MODELS, |
|
0 commit comments