diff --git a/get_pytest_options.py b/get_pytest_options.py new file mode 100644 index 0000000..f89eab8 --- /dev/null +++ b/get_pytest_options.py @@ -0,0 +1,9 @@ +import pytest + + +def main(): + pytest.main(args=['-k', 'neotest_none'], plugins=['neotest_python.pytest']) + + +if __name__ == "__main__": + main() diff --git a/lua/neotest-python/init.lua b/lua/neotest-python/init.lua index 8d9a750..811a068 100644 --- a/lua/neotest-python/init.lua +++ b/lua/neotest-python/init.lua @@ -3,15 +3,23 @@ local lib = require("neotest.lib") local base = require("neotest-python.base") local pytest = require("neotest-python.pytest") -local function get_script() - local paths = vim.api.nvim_get_runtime_file("neotest.py", true) +local function get_python_script(filename) + local paths = vim.api.nvim_get_runtime_file(filename, true) for _, path in ipairs(paths) do - if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then + if vim.endswith(path, ("neotest-python%s%s"):format(lib.files.sep, filename)) then return path end end - error("neotest.py not found") + error(string.format("%s not found", filename)) +end + +local function get_config_loading_script() + return get_python_script("get_pytest_options.py") +end + +local function get_main_script() + return get_python_script("neotest.py") end local dap_args @@ -85,25 +93,43 @@ function PythonNeotestAdapter.filter_dir(name) return name ~= "venv" end +function PythonNeotestAdapter.init_python_functions(python, runner) + if PythonNeotestAdapter.python_functions == nil then + local python_functions = "^test" + if runner == "pytest" and pytest_discover_instances then + local cmd = vim.tbl_flatten({ python, get_config_loading_script() }) + local _, data = lib.process.run(cmd, { stdout = true, stderr = true }) + + for line in vim.gsplit(data.stdout, "\n", true) do + if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then + local config = vim.json.decode(line) + python_functions = config.python_functions + end + end + end + PythonNeotestAdapter.python_functions = python_functions + end +end + ---@async ---@return neotest.Tree | nil function PythonNeotestAdapter.discover_positions(path) local root = PythonNeotestAdapter.root(path) or vim.loop.cwd() local python = get_python(root) local runner = get_runner(python) + PythonNeotestAdapter.init_python_functions(python, runner) - -- Parse the file while pytest is running - local query = [[ + local query = string.format([[ ;; Match undecorated functions ((function_definition name: (identifier) @test.name) - (#match? @test.name "^test")) + (#match? @test.name "%s")) @test.definition ;; Match decorated function, including decorators in definition (decorated_definition ((function_definition name: (identifier) @test.name) - (#match? @test.name "^test"))) + (#match? @test.name "%s"))) @test.definition ;; Match decorated classes, including decorators in definition @@ -119,13 +145,13 @@ function PythonNeotestAdapter.discover_positions(path) @namespace.definition (#not-has-parent? @namespace.definition decorated_definition) ) - ]] + ]], PythonNeotestAdapter.python_functions, PythonNeotestAdapter.python_functions) local positions = lib.treesitter.parse_positions(path, query, { require_namespaces = runner == "unittest", }) if runner == "pytest" and pytest_discover_instances then - pytest.augment_positions(python, get_script(), path, positions, root) + pytest.augment_positions(python, get_main_script(), path, positions, root) end return positions @@ -165,7 +191,7 @@ function PythonNeotestAdapter.build_spec(args) if position then table.insert(script_args, position.id) end - local python_script = get_script() + local python_script = get_main_script() local command = vim.tbl_flatten({ python, python_script, diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index a34454c..359db36 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -1,3 +1,4 @@ +import json from io import StringIO from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -9,6 +10,11 @@ from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus +def pytest_collection_modifyitems(config): + config = {"python_functions": config.getini("python_functions")[0]} + print(f"\n{json.dumps(config)}\n") + + class PytestNeotestAdapter(NeotestAdapter): def __init__(self, emit_parameterized_ids: bool): self.emit_parameterized_ids = emit_parameterized_ids