diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index 6d4a3d3..41c6322 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -107,21 +107,45 @@ end ---@param python_command string[] ---@param config neotest-python._AdapterConfig ---@param runner string ----@return string -local function scan_test_function_pattern(runner, config, python_command) +---@return table {test_pattern: string, namespace_pattern: string} +local function scan_pytest_config(runner, config, python_command) local test_function_pattern = "^test" + local namespace_pattern = "" -- For describe_prefixes + if runner == "pytest" and config.pytest_discovery then local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" }) local _, data = lib.process.run(cmd, { stdout = true, stderr = true }) for line in vim.gsplit(data.stdout, "\n", true) do - if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then + if string.sub(line, 1, 1) == "{" then local pytest_option = vim.json.decode(line) - test_function_pattern = pytest_option.python_functions + + -- Extract python_functions pattern + if pytest_option.python_functions then + test_function_pattern = pytest_option.python_functions + end + + -- Extract describe_prefixes pattern (from pytest-describe plugin) + if pytest_option.describe_prefixes then + local prefixes = vim.split(pytest_option.describe_prefixes, " ", { trimempty = true }) + local prefix_patterns = vim.tbl_map(function(p) + return "^" .. p .. "_" + end, prefixes) + namespace_pattern = table.concat(prefix_patterns, "|") + end end end end - return test_function_pattern + + -- Default namespace patterns if none configured + if namespace_pattern == "" then + namespace_pattern = "^(describe_|context_|when_|given_|scenario_|requirement_)" + end + + return { + test_pattern = test_function_pattern, + namespace_pattern = namespace_pattern, + } end ---@param python_command string[] @@ -129,9 +153,19 @@ end ---@param runner string ---@return string M.treesitter_queries = function(runner, config, python_command) - local test_function_pattern = scan_test_function_pattern(runner, config, python_command) + local patterns = scan_pytest_config(runner, config, python_command) + local test_function_pattern = patterns.test_pattern + local namespace_pattern = patterns.namespace_pattern + return string.format([[ - ;; Match undecorated functions + ;; Match container functions (describe_*, context_*, when_*, given_*, scenario_*, requirement_*) + ;; These create namespaces for organizing tests + ((function_definition + name: (identifier) @namespace.name) + (#match? @namespace.name "%s")) + @namespace.definition + + ;; Match undecorated test functions ((function_definition name: (identifier) @test.name) (#match? @test.name "%s")) @@ -148,7 +182,7 @@ M.treesitter_queries = function(runner, config, python_command) (decorated_definition (class_definition name: (identifier) @namespace.name)) - @namespace.definition + @namespace.definition ;; Match undecorated classes: namespaces nest so #not-has-parent is used ;; to ensure each namespace is annotated only once @@ -158,7 +192,7 @@ M.treesitter_queries = function(runner, config, python_command) @namespace.definition (#not-has-parent? @namespace.definition decorated_definition) ) - ]], test_function_pattern, test_function_pattern) + ]], namespace_pattern, test_function_pattern, test_function_pattern) end M.get_root = diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index 57e3e27..ca73508 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -240,8 +240,14 @@ def maybe_debugpy_postmortem(excinfo): class TestNameTemplateExtractor: @staticmethod def pytest_collection_modifyitems(config): - config = {"python_functions": config.getini("python_functions")[0]} - print(f"\n{json.dumps(config)}\n") + extracted_config = {"python_functions": config.getini("python_functions")[0]} + + # Extract describe_prefixes if pytest-describe is configured + describe_prefixes = config.getini("describe_prefixes") + if describe_prefixes: + extracted_config["describe_prefixes"] = " ".join(describe_prefixes) + + print(f"\n{json.dumps(extracted_config)}\n") def extract_test_name_template(args) -> int: