Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: detect subtests #1

Merged
merged 7 commits into from
Jun 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions lua/neotest-go/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ local function sanitize_output(output)
return output:gsub('\n', ''):gsub('\t', '')
end

-- replace whitespace with underscores and remove surrounding quotes
local function transform_test_name(name)
return name:gsub('[%s]', '_'):gsub('^"(.*)"$', '%1')
end

---Get a line in a buffer, defaulting to the first if none is specified
---@param buf number
---@param nr number?
Expand Down Expand Up @@ -126,25 +131,68 @@ function adapter.is_test_file(file_path)
return is_test
end

---@param position neotest.Position The position to return an ID for
---@param namespaces neotest.Position[] Any namespaces the position is within
local function generate_position_id(position, namespaces)
local prefix = {}
for _, namespace in ipairs(namespaces) do
if namespace.type ~= 'file' then
table.insert(prefix, namespace.name)
end
end
local name = transform_test_name(position.name)
return table.concat(vim.tbl_flatten({ position.path, prefix, name }), '::')
end

---@async
---@return neotest.Tree| nil
function adapter.discover_positions(path)
-- NOTE: this is a query to annotate a test function as a namespace
-- ((function_declaration
-- name: (identifier) @namespace.name
-- body: (block
-- (call_expression
-- function: (selector_expression
-- field: (field_identifier) @_method)
-- (#match? @_method "^Run"))))
-- (#match? @namespace.name "^Test"))
-- @namespace.definition
local query = [[
((function_declaration
name: (identifier) @test.name)
(#match? @test.name "^Test"))
@test.definition

(call_expression
function: (selector_expression
field: (field_identifier) @test.method)
(#match? @test.method "^Run$")
arguments: (argument_list . (_) @test.name))
@test.definition

(package_clause
(package_identifier) @namespace.name)
@namespace.definition
]]
return lib.treesitter.parse_positions(path, query, {
require_namespaces = false,
nested_tests = true,
position_id = generate_position_id,
})
end

---@param tree neotest.Tree
---@param name string
---@return string
local function get_prefix(tree, name)
local parent_tree = tree:parent()
if not parent_tree or parent_tree:data().type == 'file' then
return name
end
local parent_name = parent_tree:data().name
return parent_name .. '/' .. name
end

---@async
---@param args neotest.RunArgs
---@return neotest.RunSpec
Expand All @@ -158,7 +206,7 @@ function adapter.build_spec(args)
dir = { dir .. '/...' },
file = { position.path },
namespace = { package },
test = { '-run', position.name .. '$', dir },
test = { '-run', get_prefix(args.tree, position.name) .. '$', dir },
})[position.type]

local command = { 'go', 'test', '-v', '-json', unpack(cmd_args) }
Expand All @@ -168,7 +216,7 @@ function adapter.build_spec(args)
end

return {
command = command,
command = table.concat(command, ' '),
context = {
results_path = results_path,
file = position.path,
Expand All @@ -195,14 +243,17 @@ function adapter.results(_, result, tree)
empty_result_fname = async.fn.tempname()
fn.writefile(tests.__unnamed.output, empty_result_fname)
end
for _, value in tree:iter() do
for _, node in tree:iter_nodes() do
local value = node:data()
if no_results then
results[value.id] = {
status = 'skipped',
output = empty_result_fname,
}
else
local test_output = tests[value.name]
local id_parts = vim.split(value.id, '::')
table.remove(id_parts, 1)
local test_output = tests[table.concat(id_parts, '/')]
if test_output then
local fname = async.fn.tempname()
fn.writefile(test_output.output, fname)
Expand Down
9 changes: 9 additions & 0 deletions neotest_go/cases.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package main

func add(a, b int) int {
return a + b
}

func subtract(a, b int) int {
return a - b
}
49 changes: 49 additions & 0 deletions neotest_go/cases_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSubtract(t *testing.T) {
testCases := []struct {
desc string
a int
b int
want int
}{
{
desc: "test one",
a: 1,
b: 2,
want: 3,
},
{
desc: "test two",
a: 1,
b: 2,
want: 7,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
assert.Equal(t, tC.want, subtract(tC.a, tC.b))
})
}
}

func TestAdd(t *testing.T) {
t.Run("test one", func(t *testing.T) {
assert.Equal(t, 3, add(1, 2))
})

t.Run("test two", func(t *testing.T) {
assert.Equal(t, 5, add(1, 2))
})

variable := "string"
t.Run(variable, func(t *testing.T) {
assert.Equal(t, 3, add(1, 2))
})
}