Skip to content

Commit

Permalink
Extract python_functions & use it in treesitter query
Browse files Browse the repository at this point in the history
  • Loading branch information
nicos68 committed Feb 11, 2024
1 parent 2e83d2b commit a485965
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 11 deletions.
9 changes: 9 additions & 0 deletions get_pytest_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest


def main():
pytest.main(args=['-k', 'neotest_none'], plugins=['neotest_python.pytest'])


if __name__ == "__main__":
main()
48 changes: 37 additions & 11 deletions lua/neotest-python/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@ local lib = require("neotest.lib")
local base = require("neotest-python.base")
local pytest = require("neotest-python.pytest")

local function get_script()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
local function get_python_script(filename)
local paths = vim.api.nvim_get_runtime_file(filename, true)
for _, path in ipairs(paths) do
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
if vim.endswith(path, ("neotest-python%s%s"):format(lib.files.sep, filename)) then
return path
end
end

error("neotest.py not found")
error(string.format("%s not found", filename))
end

local function get_config_loading_script()
return get_python_script("get_pytest_options.py")
end

local function get_main_script()
return get_python_script("neotest.py")
end

local dap_args
Expand Down Expand Up @@ -85,25 +93,43 @@ function PythonNeotestAdapter.filter_dir(name)
return name ~= "venv"
end

function PythonNeotestAdapter.init_python_functions(python, runner)
if PythonNeotestAdapter.python_functions == nil then
local python_functions = "^test"
if runner == "pytest" and pytest_discover_instances then
local cmd = vim.tbl_flatten({ python, get_config_loading_script() })
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })

for line in vim.gsplit(data.stdout, "\n", true) do
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
local config = vim.json.decode(line)
python_functions = config.python_functions
end
end
end
PythonNeotestAdapter.python_functions = python_functions
end
end

---@async
---@return neotest.Tree | nil
function PythonNeotestAdapter.discover_positions(path)
local root = PythonNeotestAdapter.root(path) or vim.loop.cwd()
local python = get_python(root)
local runner = get_runner(python)
PythonNeotestAdapter.init_python_functions(python, runner)

-- Parse the file while pytest is running
local query = [[
local query = string.format([[
;; Match undecorated functions
((function_definition
name: (identifier) @test.name)
(#match? @test.name "^test"))
(#match? @test.name "%s"))
@test.definition
;; Match decorated function, including decorators in definition
(decorated_definition
((function_definition
name: (identifier) @test.name)
(#match? @test.name "^test")))
(#match? @test.name "%s")))
@test.definition
;; Match decorated classes, including decorators in definition
Expand All @@ -119,13 +145,13 @@ function PythonNeotestAdapter.discover_positions(path)
@namespace.definition
(#not-has-parent? @namespace.definition decorated_definition)
)
]]
]], PythonNeotestAdapter.python_functions, PythonNeotestAdapter.python_functions)
local positions = lib.treesitter.parse_positions(path, query, {
require_namespaces = runner == "unittest",
})

if runner == "pytest" and pytest_discover_instances then
pytest.augment_positions(python, get_script(), path, positions, root)
pytest.augment_positions(python, get_main_script(), path, positions, root)
end

return positions
Expand Down Expand Up @@ -165,7 +191,7 @@ function PythonNeotestAdapter.build_spec(args)
if position then
table.insert(script_args, position.id)
end
local python_script = get_script()
local python_script = get_main_script()
local command = vim.tbl_flatten({
python,
python_script,
Expand Down
6 changes: 6 additions & 0 deletions neotest_python/pytest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
Expand All @@ -9,6 +10,11 @@
from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus


def pytest_collection_modifyitems(config):
config = {"python_functions": config.getini("python_functions")[0]}
print(f"\n{json.dumps(config)}\n")


class PytestNeotestAdapter(NeotestAdapter):
def __init__(self, emit_parameterized_ids: bool):
self.emit_parameterized_ids = emit_parameterized_ids
Expand Down

0 comments on commit a485965

Please sign in to comment.