From 46b41d13d6943443c20b3bf87fdf8eb495fee4c2 Mon Sep 17 00:00:00 2001 From: Folke Lemaitre Date: Mon, 12 Jun 2023 23:27:21 +0200 Subject: [PATCH] feat: added ffi based searcher. Finally 100% correct end pos for matches --- lua/flash/searcher.lua | 94 ++++++++++++++++++++++++++++++++-- tests/search/searcher_spec.lua | 30 +++++++++++ 2 files changed, 119 insertions(+), 5 deletions(-) diff --git a/lua/flash/searcher.lua b/lua/flash/searcher.lua index c1e4bf0..5f76fdc 100644 --- a/lua/flash/searcher.lua +++ b/lua/flash/searcher.lua @@ -1,4 +1,6 @@ local M = {} +---@type "vim" | "regex" | "ffi" +M.api = "ffi" ---@class Flash.Match ---@field win window @@ -17,6 +19,25 @@ function M.search(win, state) end) end +---@type ffi.namespace* +local C + +---@param from number[] +function M.get_match_end(from) + if not C then + local ffi = require("ffi") + ffi.cdef([[ + int search_match_endcol; + unsigned int search_match_lines; + ]]) + C = ffi.C + end + return { + from[1] + C.search_match_lines, + math.max(0, C.search_match_endcol - 1), + } +end + ---@param state Flash.State ---@return Flash.Match[] function M._search(win, state) @@ -65,6 +86,10 @@ end ---@param flags? string ---@param k? number function M.get_matches(pattern, flags, k) + if M.api ~= "ffi" then + return M.get_matches_old(pattern, flags, k) + end + local view = vim.fn.winsaveview() flags = flags or "" @@ -72,7 +97,51 @@ function M.get_matches(pattern, flags, k) ---@type Flash.Match[] local matches = {} - local ok, re = pcall(vim.regex, pattern .. (vim.go.ignorecase and "\\c" or "")) + local count = vim.fn.searchcount({ pattern = pattern, recompute = true, maxcount = k }).total or 0 + + local function next(f, p) + local from = vim.fn.searchpos(p or pattern, f) + return from[1] ~= 0 and { from[1], from[2] - 1 } or nil + end + + while #matches < count do + local from = next(flags) + if not from then + break + end + + local to = M.get_match_end(from) + + if not to then + break + end + + table.insert(matches, { + from = from, + to = to, + first = #matches == 0, + }) + end + + vim.fn.winrestview(view) + return matches +end + +---@param pattern string +---@param flags? string +---@param k? number +function M.get_matches_old(pattern, flags, k) + local view = vim.fn.winsaveview() + + flags = flags or "" + + ---@type Flash.Match[] + local matches = {} + + local orig = pattern + pattern = pattern:gsub("\\z[se]", "") + + local ok, re = pcall(vim.regex, orig .. (vim.go.ignorecase and "\\c" or "")) if not ok then return {} -- invalid pattern, bail out end @@ -80,8 +149,8 @@ function M.get_matches(pattern, flags, k) local buf = vim.api.nvim_get_current_buf() local count = vim.fn.searchcount({ pattern = pattern, recompute = true, maxcount = k }).total or 0 - local function next(f) - local from = vim.fn.searchpos(pattern, f) + local function next(f, p) + local from = vim.fn.searchpos(p or pattern, f) return from[1] ~= 0 and { from[1], from[2] - 1 } or nil end @@ -90,8 +159,23 @@ function M.get_matches(pattern, flags, k) if not from then break end - local col_start, col_end = re:match_line(buf, from[1] - 1, from[2]) - local to = col_start and { from[1], math.max(col_end + from[2] - 1, 0) } + + from = next("cn") + + -- ---@type number[] | nil + local to + if M.api == "vim" then + local line = vim.api.nvim_buf_get_lines(buf, from[1] - 1, from[1], false)[1] + local m = vim.fn.matchstrpos(line, orig, from[2]) + to = m[2] ~= -1 and { from[1], math.max(m[3] - 1, 0) } + from[2] = m[2] + elseif M.api == "regex" then + local col_start, col_end = re:match_line(buf, from[1] - 1, from[2]) + to = col_start and { from[1], math.max(col_end + from[2] - 1, 0) } + from[2] = from[2] + (col_start or 0) + else + error("invalid api " .. M.api) + end -- `s` will be `nil` or non-zero for multi-line matches, -- Since this is a non-zero-width match, we can use `searchpos` diff --git a/tests/search/searcher_spec.lua b/tests/search/searcher_spec.lua index 1dec37f..da28aaa 100644 --- a/tests/search/searcher_spec.lua +++ b/tests/search/searcher_spec.lua @@ -87,6 +87,36 @@ describe("searcher", function() }, matches) end) + it("handles '\\Vi\\zs\\.'", function() + set([[ + line1 + line2 + line3 + ]]) + + local matches = Searcher.get_matches([[\Vi\zs\m.]]) + assert.same({ + { first = true, from = { 1, 2 }, to = { 1, 2 } }, + { first = false, from = { 2, 2 }, to = { 2, 2 } }, + { first = false, from = { 3, 2 }, to = { 3, 2 } }, + }, matches) + end) + + it("handles ^", function() + set([[ + foobar + line1 + line2 + ]]) + + local matches = Searcher.get_matches("^") + assert.same({ + { first = true, from = { 2, 0 }, to = { 2, 0 } }, + { first = false, from = { 3, 0 }, to = { 3, 0 } }, + { first = false, from = { 1, 0 }, to = { 1, 0 } }, + }, matches) + end) + it("handles ^", function() set([[ foobar