-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(async): implement the concept of futures and handle async operat…
…ions via coroutines
- Loading branch information
Showing
2 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
---@class Future | ||
---@field _state 'pending' | 'fulfilled' | 'rejected' | ||
---@field _value any | ||
---@field _callbacks { on_fulfilled: fun(value: any), on_rejected: fun(reason: any) }[] | ||
local Future = {} | ||
local mt = { __index = Future } | ||
|
||
function Future.new(executor) | ||
local self = setmetatable({}, mt) | ||
self._state = 'pending' | ||
self._value = nil | ||
self._callbacks = {} | ||
|
||
local function resolve(value) | ||
if self._state ~= 'pending' then return end | ||
self._state = 'fulfilled' | ||
self._value = value | ||
for _, callback in ipairs(self._callbacks) do | ||
callback.on_fulfilled(value) | ||
end | ||
end | ||
|
||
local function reject(reason) | ||
if self._state ~= 'pending' then return end | ||
self._state = 'rejected' | ||
self._value = reason | ||
for _, callback in ipairs(self._callbacks) do | ||
callback.on_rejected(reason) | ||
end | ||
end | ||
|
||
xpcall(function() executor(resolve, reject) end, function(err) | ||
require('cord.plugin.log').tracecb( | ||
function() | ||
return 'Error in executor: ' .. err .. '\n' .. debug.traceback() | ||
end | ||
) | ||
reject(err) | ||
end) | ||
|
||
return self | ||
end | ||
|
||
function Future:and_then(on_fulfilled, on_rejected) | ||
local current = coroutine.running() | ||
if not current then | ||
require('cord.plugin.log').errorcb( | ||
function() | ||
return 'Future:and_then must be called within a coroutine\n' | ||
.. debug.traceback() | ||
end | ||
) | ||
return | ||
end | ||
|
||
return Future.new(function(resolve, reject) | ||
local function handle_callback(callback, resolve, reject, value) | ||
if type(callback) ~= 'function' then | ||
if self._state == 'fulfilled' then | ||
resolve(value or self._value) | ||
else | ||
reject(value or self._value) | ||
end | ||
return | ||
end | ||
|
||
local success, result = xpcall( | ||
function() return callback(value or self._value) end, | ||
function(err) | ||
require('cord.plugin.log').tracecb( | ||
function() | ||
return 'Error in callback: ' .. err .. '\n' .. debug.traceback() | ||
end | ||
) | ||
end | ||
) | ||
|
||
if not success then | ||
reject(result) | ||
return | ||
end | ||
|
||
if type(result) == 'table' and result._state then | ||
result:and_then(resolve, reject) | ||
else | ||
resolve(result) | ||
end | ||
end | ||
|
||
if self._state == 'pending' then | ||
table.insert(self._callbacks, { | ||
on_fulfilled = function(value) | ||
handle_callback(on_fulfilled, resolve, reject, value) | ||
end, | ||
on_rejected = function(reason) | ||
handle_callback(on_rejected, resolve, reject, reason) | ||
end, | ||
}) | ||
else | ||
vim.defer_fn(function() | ||
if self._state == 'fulfilled' then | ||
handle_callback(on_fulfilled, resolve, reject) | ||
else | ||
handle_callback(on_rejected, resolve, reject) | ||
end | ||
end, 0) | ||
end | ||
end) | ||
end | ||
|
||
function Future:catch(on_rejected) return self:and_then(nil, on_rejected) end | ||
|
||
function Future.await(future) | ||
local co = coroutine.running() | ||
if not co then | ||
require('cord.plugin.log').errorcb( | ||
function() | ||
return 'Future:await must be called within a coroutine\n' | ||
.. debug.traceback() | ||
end | ||
) | ||
end | ||
|
||
future:and_then( | ||
function(value) coroutine.resume(co, true, value) end, | ||
function(reason) coroutine.resume(co, false, reason) end | ||
) | ||
|
||
local success, result = coroutine.yield() | ||
if success then | ||
return result | ||
else | ||
error(result) | ||
end | ||
end | ||
|
||
function Future.get(future) | ||
local co = coroutine.running() | ||
if not co then | ||
require('cord.plugin.log').errorcb( | ||
function() | ||
return 'Future:get must be called within a coroutine\n' | ||
.. debug.traceback() | ||
end | ||
) | ||
end | ||
|
||
future:and_then( | ||
function(value) coroutine.resume(co, true, value) end, | ||
function(reason) coroutine.resume(co, false, reason) end | ||
) | ||
|
||
local success, result = coroutine.yield() | ||
if success then | ||
return result | ||
else | ||
return nil, result | ||
end | ||
end | ||
|
||
function Future.all(futures) | ||
return Future.new(function(resolve, reject) | ||
local results = {} | ||
local completed = 0 | ||
for i, future in ipairs(futures) do | ||
future | ||
:and_then(function(result) | ||
results[i] = result | ||
completed = completed + 1 | ||
if completed == #futures then resolve(results) end | ||
end) | ||
:catch(reject) | ||
end | ||
end) | ||
end | ||
|
||
return Future |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
local Future = require 'cord.core.async.future' | ||
|
||
local Async = {} | ||
|
||
function Async.wrap(fn) | ||
return function(...) | ||
local args = { ... } | ||
return Future.new(function(resolve, reject) | ||
local current = coroutine.running() | ||
if not current then | ||
require('cord.plugin.log').errorcb( | ||
function() | ||
return 'async.wrap must be called within a coroutine\n' | ||
.. debug.traceback() | ||
end | ||
) | ||
return | ||
end | ||
|
||
local success, result = xpcall(function() | ||
---@diagnostic disable-next-line: deprecated | ||
local unpack = table.unpack or unpack | ||
return fn(unpack(args)) | ||
end, function(err) | ||
require('cord.plugin.log').tracecb( | ||
function() | ||
return 'Error in async.wrap: ' .. err .. '\n' .. debug.traceback() | ||
end | ||
) | ||
end) | ||
|
||
if not success then | ||
reject(result) | ||
return | ||
end | ||
|
||
if type(result) == 'table' and result._state then | ||
result:and_then(resolve, reject) | ||
else | ||
resolve(result) | ||
end | ||
end) | ||
end | ||
end | ||
|
||
function Async.run(fn) | ||
local co = coroutine.create(fn) | ||
local function resume(success, ...) | ||
if not success then | ||
error(...) | ||
return | ||
end | ||
|
||
local ret = { coroutine.resume(co, ...) } | ||
success = table.remove(ret, 1) | ||
|
||
if success then | ||
if coroutine.status(co) ~= 'dead' then | ||
local future = ret[1] | ||
if future then | ||
if type(future) == 'table' and future._state then | ||
future:and_then(function(value) | ||
if coroutine.status(co) ~= 'dead' then resume(true, value) end | ||
end, function(err) resume(false, err) end) | ||
else | ||
resume(true, future) | ||
end | ||
end | ||
end | ||
else | ||
error(ret[1]) | ||
end | ||
end | ||
|
||
resume(true) | ||
return co | ||
end | ||
|
||
return Async |