Skip to content

Commit

Permalink
Merge pull request #2486 from fesily/plugin-OnNodeCompileFunctionParam
Browse files Browse the repository at this point in the history
Plugin on node compile function param
  • Loading branch information
sumneko authored Jan 24, 2024
2 parents ea3aed4 + 155f831 commit 0a962fc
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 50 deletions.
57 changes: 57 additions & 0 deletions script/plugin.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ local scope = require 'workspace.scope'
local ws = require 'workspace'
local fs = require 'bee.filesystem'

---@class pluginInterfaces
local pluginConfigs = {
-- create plugin for vm module
VM = {
OnCompileFunctionParam = function (next, func, source)
end
}
}

---@class plugin
local m = {}

Expand Down Expand Up @@ -51,6 +60,15 @@ function m.dispatch(event, uri, ...)
return failed == 0, res1, res2
end

function m.getVmPlugin(uri)
local scp = scope.getScope(uri)
local interfaces = scp:get('pluginInterfaces')
if not interfaces then
return
end
return interfaces.VM
end

---@async
---@param scp scope
local function checkTrustLoad(scp)
Expand Down Expand Up @@ -78,6 +96,40 @@ local function checkTrustLoad(scp)
return true
end

local function createMethodGroup(interfaces, key, methods)
local methodGroup = {}

for method in pairs(methods) do
local funcs = setmetatable({}, {
__call = function (t, next, ...)
if #t == 0 then
return next(...)
else
local result
for _, fn in ipairs(t) do
result = fn(next, ...)
end
return result
end
end
})
for _, interface in ipairs(interfaces) do
local func = interface[method]
if not func then
local namespace = interface[key]
if namespace then
func = namespace[method]
end
end
if func then
funcs[#funcs+1] = func
end
end
methodGroup[method] = funcs
end
return #methodGroup>0 and methodGroup or nil
end

---@param uri uri
local function initPlugin(uri)
await.call(function () ---@async
Expand Down Expand Up @@ -148,6 +200,11 @@ local function initPlugin(uri)
end
interfaces[#interfaces+1] = interface
end

for key, config in pairs(pluginConfigs) do
interfaces[key] = createMethodGroup(interfaces, key, config)
end

ws.resetFiles(scp)
end)
end
Expand Down
75 changes: 75 additions & 0 deletions script/plugins/nodeHelper.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
local vm = require 'vm'
local guide = require 'parser.guide'

local _M = {}

---@class node.match.pattern
---@field next node.match.pattern?

local function deepCompare(source, pattern)
local type1, type2 = type(source), type(pattern)
if type1 ~= type2 then
return false
end

if type1 ~= "table" then
return source == pattern
end

for key2, value2 in pairs(pattern) do
local value1 = source[key2]
if value1 == nil or not deepCompare(value1, value2) then
return false
end
end

return true
end

---@param source parser.object
---@param pattern node.match.pattern
---@return boolean
function _M.matchPattern(source, pattern)
if source.type == 'local' then
if source.parent.type == 'funcargs' and source.parent.parent.type == 'function' then
for i, ref in ipairs(source.ref) do
if deepCompare(ref, pattern) then
return true
end
end
end
end
return false
end

local vaildVarRegex = "()([a-zA-Z][a-zA-Z0-9_]*)()"
---创建类型 *.field.field形式的 pattern
---@param pattern string
---@return node.match.pattern?, string?
function _M.createFieldPattern(pattern)
local ret = { next = nil }
local next = ret
local init = 1
while true do
local startpos, matched, endpos
if pattern:sub(1, 1) == "*" then
startpos, matched, endpos = init, "*", init + 1
else
startpos, matched, endpos = vaildVarRegex:match(pattern, init)
end
if not startpos then
break
end
if startpos ~= init then
return nil, "invalid pattern"
end
local field = matched == "*" and { next = nil }
or { field = { type = 'field', matched }, type = 'getfield', next = nil }
next.next = field
next = field
pattern = pattern:sub(endpos)
end
return ret
end

return _M
101 changes: 53 additions & 48 deletions script/vm/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local rpath = require 'workspace.require-path'
local files = require 'files'
---@class vm
local vm = require 'vm.vm'
local plugin = require 'plugin'

---@class parser.object
---@field _compiledNodes boolean
Expand Down Expand Up @@ -1030,6 +1031,55 @@ local function compileForVars(source, target)
return false
end

---@param source parser.object
local function compileFunctionParam(func, source)
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
local argNode = vm.compileNode(arg)
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
return true
end
end
end
end
if func.parent.type == 'local' then
local refs = func.parent.ref
local findCall
if refs then
for i, ref in ipairs(refs) do
if ref.parent.type == 'call' then
findCall = ref.parent
break
end
end
end
if findCall and findCall.args then
local index
for i, arg in ipairs(source.parent) do
if arg == source then
index = i
break
end
end
if index then
local callerArg = findCall.args[index]
if callerArg then
vm.setNode(source, vm.compileNode(callerArg))
return true
end
end
end
end
end

---@param source parser.object
local function compileLocal(source)
local myNode = vm.setNode(source, source)
Expand Down Expand Up @@ -1069,56 +1119,11 @@ local function compileLocal(source)
vm.setNode(source, vm.compileNode(setfield.node))
end
end

if source.parent.type == 'funcargs' and not hasMarkDoc and not hasMarkParam then
local func = source.parent.parent
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
local hasDocArg
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
local argNode = vm.compileNode(arg)
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
hasDocArg = true
end
end
end
end
if not hasDocArg
and func.parent.type == 'local' then
local refs = func.parent.ref
local findCall
if refs then
for i, ref in ipairs(refs) do
if ref.parent.type == 'call' then
findCall = ref.parent
break
end
end
end
if findCall and findCall.args then
local index
for i, arg in ipairs(source.parent) do
if arg == source then
index = i
break
end
end
if index then
local callerArg = findCall.args[index]
if callerArg then
hasDocArg = true
vm.setNode(source, vm.compileNode(callerArg))
end
end
end
end
local vmPlugin = plugin.getVmPlugin(guide.getUri(source))
local hasDocArg = vmPlugin and vmPlugin.OnCompileFunctionParam(compileFunctionParam, func, source)
or compileFunctionParam(func, source)
if not hasDocArg then
vm.setNode(source, vm.declareGlobal('type', 'any'))
end
Expand Down
5 changes: 3 additions & 2 deletions script/workspace/scope.lua
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ function mt:set(k, v)
return v
end

---@param k string
---@return any
---@generic T
---@param k `T`
---@return T
function mt:get(k)
return self._data[k]
end
Expand Down
56 changes: 56 additions & 0 deletions test/plugins/node/test.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
local files = require 'files'
local scope = require 'workspace.scope'
local nodeHelper = require 'plugins.nodeHelper'
local vm = require 'vm'
local guide = require 'parser.guide'


local pattern, msg = nodeHelper.createFieldPattern("*.components")
assert(pattern, msg)

---@param source parser.object
function OnCompileFunctionParam(next, func, source)
if next(func, source) then
return true
end
--从该参数的使用模式来推导该类型
if nodeHelper.matchPattern(source, pattern) then
local type = vm.declareGlobal('type', 'TestClass', TESTURI)
vm.setNode(source, vm.createNode(type, source))
return true
end
end

local myplugin = { OnCompileFunctionParam = OnCompileFunctionParam }

---@diagnostic disable: await-in-sync
local function TestPlugin(script)
local prefix = [[
---@class TestClass
---@field b string
]]
---@param checker fun(state:parser.state)
return function (plugin, checker)
files.open(TESTURI)
files.setText(TESTURI, prefix .. script, true)
scope.getScope(TESTURI):set('pluginInterfaces', plugin)
local state = files.getState(TESTURI)
assert(state)
checker(state)
files.remove(TESTURI)
end
end

TestPlugin [[
local function t(a)
a.components:test()
end
]](myplugin, function (state)
guide.eachSourceType(state.ast, 'local', function (src)
if guide.getKeyName(src) == 'a' then
local node = vm.compileNode(src)
assert(node)
assert(not vm.isUnknown(node))
end
end)
end)
1 change: 1 addition & 0 deletions test/plugins/test.lua
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
require 'plugins.ast.test'
require 'plugins.ffi.test'
require 'plugins.node.test'

0 comments on commit 0a962fc

Please sign in to comment.