-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add delay_cancellation
utility function
#12180
Changes from 6 commits
fa97936
b02f8e8
a9550f2
6feaf20
e9210b2
ab9b9e1
6de5030
543fafb
2542f74
a997133
27274dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -686,12 +686,69 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": | |
Synapse logcontext rules. | ||
|
||
Returns: | ||
A new `Deferred`, which will contain the result of the original `Deferred`, | ||
but will not propagate cancellation through to the original. When cancelled, | ||
the new `Deferred` will fail with a `CancelledError` and will not follow the | ||
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap | ||
the new `Deferred`. | ||
A new `Deferred`, which will contain the result of the original `Deferred`. | ||
The new `Deferred` will not propagate cancellation through to the original. | ||
When cancelled, the new `Deferred` will fail with a `CancelledError`. | ||
|
||
The new `Deferred` will not follow the Synapse logcontext rules and should be | ||
wrapped with `make_deferred_yieldable`. | ||
""" | ||
new_deferred: defer.Deferred[T] = defer.Deferred() | ||
new_deferred: "defer.Deferred[T]" = defer.Deferred() | ||
deferred.chainDeferred(new_deferred) | ||
return new_deferred | ||
|
||
|
||
def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]": | ||
"""Delay cancellation of a `Deferred` until it resolves. | ||
|
||
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not | ||
resolve with a `CancelledError` until the original `Deferred` resolves. | ||
|
||
Args: | ||
deferred: The `Deferred` to protect against cancellation. Must not follow the | ||
Synapse logcontext rules if `all` is `False`. | ||
all: `True` to delay multiple cancellations. `False` to delay only the first | ||
cancellation. | ||
|
||
Returns: | ||
A new `Deferred`, which will contain the result of the original `Deferred`. | ||
The new `Deferred` will not propagate cancellation through to the original. | ||
When cancelled, the new `Deferred` will wait until the original `Deferred` | ||
resolves before failing with a `CancelledError`. | ||
|
||
The new `Deferred` will only follow the Synapse logcontext rules if `all` is | ||
`True` and `deferred` follows the Synapse logcontext rules. Otherwise the new | ||
`Deferred` should be wrapped with `make_deferred_yieldable`. | ||
""" | ||
|
||
def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]: | ||
"""Insert another `Deferred` into the chain to delay cancellation. | ||
|
||
Called when the original `Deferred` resolves or the new `Deferred` is | ||
cancelled. | ||
""" | ||
failure.trap(CancelledError) | ||
|
||
if deferred.called and not deferred.paused: | ||
# The `CancelledError` came from the original `Deferred`. Pass it through. | ||
return failure | ||
|
||
# Construct another `Deferred` that will only fail with the `CancelledError` | ||
# once the original `Deferred` resolves. | ||
delay_deferred: "defer.Deferred[T]" = defer.Deferred() | ||
deferred.chainDeferred(delay_deferred) | ||
|
||
if all: | ||
# Intercept cancellations recursively. Each cancellation will cause another | ||
# `Deferred` to be inserted into the chain. | ||
delay_deferred.addErrback(cancel_errback) | ||
|
||
# Override the result with the `CancelledError`. | ||
delay_deferred.addBoth(lambda _: failure) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be easier to give There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A custom canceller was actually my first idea, but I found that twisted expects the canceller to The custom canceller would certainly be a lot cleaner if it worked! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I have an idea involving There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a lot cleaner, thanks for giving me the idea! |
||
|
||
return delay_deferred | ||
|
||
new_deferred: "defer.Deferred[T]" = defer.Deferred() | ||
deferred.chainDeferred(new_deferred) | ||
new_deferred.addErrback(cancel_errback) | ||
return new_deferred |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,9 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import traceback | ||
from typing import Callable | ||
|
||
from parameterized import parameterized_class | ||
|
||
from twisted.internet import defer | ||
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred | ||
|
@@ -23,10 +26,12 @@ | |
LoggingContext, | ||
PreserveLoggingContext, | ||
current_context, | ||
make_deferred_yieldable, | ||
) | ||
from synapse.util.async_helpers import ( | ||
ObservableDeferred, | ||
concurrently_execute, | ||
delay_cancellation, | ||
stop_cancellation, | ||
timeout_deferred, | ||
) | ||
|
@@ -313,13 +318,22 @@ async def caller(): | |
self.successResultOf(d2) | ||
|
||
|
||
class StopCancellationTests(TestCase): | ||
"""Tests for the `stop_cancellation` function.""" | ||
@parameterized_class( | ||
("wrap_deferred",), | ||
[ | ||
(lambda _self, deferred: stop_cancellation(deferred),), | ||
(lambda _self, deferred: delay_cancellation(deferred, all=True),), | ||
], | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is rather ugly. Alternatives are welcome. I previously tried an abstract base class, but trial tried to instantiate it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
did you try giving it a different name? AIUI trial picks the things to instantiate based on their name. (IIRC it's things ending in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (otherwise: rather than using a lambda, just have a boolean, and make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Renaming the class doesn't work. trial seems to look for subclasses of I'm not a fan of using booleans as enums, so I'll switch it to a string. Which is still not an enum but at least makes the meaning clear. |
||
class CancellationWrapperTests(TestCase): | ||
"""Common tests for the `stop_cancellation` and `delay_cancellation` functions.""" | ||
|
||
wrap_deferred: Callable[[TestCase, "Deferred[str]"], "Deferred[str]"] | ||
|
||
def test_succeed(self): | ||
"""Test that the new `Deferred` receives the result.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = stop_cancellation(deferred) | ||
wrapper_deferred = self.wrap_deferred(deferred) | ||
|
||
# Success should propagate through. | ||
deferred.callback("success") | ||
|
@@ -329,14 +343,18 @@ def test_succeed(self): | |
def test_failure(self): | ||
"""Test that the new `Deferred` receives the `Failure`.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = stop_cancellation(deferred) | ||
wrapper_deferred = self.wrap_deferred(deferred) | ||
|
||
# Failure should propagate through. | ||
deferred.errback(ValueError("abc")) | ||
self.assertTrue(wrapper_deferred.called) | ||
self.failureResultOf(wrapper_deferred, ValueError) | ||
self.assertIsNone(deferred.result, "`Failure` was not consumed") | ||
|
||
|
||
class StopCancellationTests(TestCase): | ||
"""Tests for the `stop_cancellation` function.""" | ||
|
||
def test_cancellation(self): | ||
"""Test that cancellation of the new `Deferred` leaves the original running.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
|
@@ -347,11 +365,121 @@ def test_cancellation(self): | |
self.assertTrue(wrapper_deferred.called) | ||
self.failureResultOf(wrapper_deferred, CancelledError) | ||
self.assertFalse( | ||
deferred.called, "Original `Deferred` was unexpectedly cancelled." | ||
deferred.called, "Original `Deferred` was unexpectedly cancelled" | ||
) | ||
|
||
# Now make the original `Deferred` fail. | ||
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed | ||
# in logs. | ||
deferred.errback(ValueError("abc")) | ||
self.assertIsNone(deferred.result, "`Failure` was not consumed") | ||
|
||
|
||
class DelayCancellationTests(TestCase): | ||
"""Tests for the `delay_cancellation` function.""" | ||
|
||
def test_cancellation(self): | ||
"""Test that cancellation of the new `Deferred` waits for the original.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = delay_cancellation(deferred, all=True) | ||
|
||
# Cancel the new `Deferred`. | ||
wrapper_deferred.cancel() | ||
self.assertNoResult(wrapper_deferred) | ||
self.assertFalse( | ||
deferred.called, "Original `Deferred` was unexpectedly cancelled" | ||
) | ||
|
||
# Now make the original `Deferred` fail. | ||
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed | ||
# in logs. | ||
deferred.errback(ValueError("abc")) | ||
self.assertIsNone(deferred.result, "`Failure` was not consumed") | ||
|
||
# Now that the original `Deferred` has failed, we should get a `CancelledError`. | ||
self.failureResultOf(wrapper_deferred, CancelledError) | ||
|
||
def test_suppresses_second_cancellation(self): | ||
"""Test that a second cancellation is suppressed when the `all` flag is set. | ||
|
||
Identical to `test_cancellation` except the new `Deferred` is cancelled twice. | ||
""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = delay_cancellation(deferred, all=True) | ||
|
||
# Cancel the new `Deferred`, twice. | ||
wrapper_deferred.cancel() | ||
wrapper_deferred.cancel() | ||
self.assertNoResult(wrapper_deferred) | ||
self.assertFalse( | ||
deferred.called, "Original `Deferred` was unexpectedly cancelled" | ||
) | ||
|
||
# Now make the inner `Deferred` fail. | ||
# Now make the original `Deferred` fail. | ||
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed | ||
# in logs. | ||
deferred.errback(ValueError("abc")) | ||
self.assertIsNone(deferred.result, "`Failure` was not consumed") | ||
|
||
# Now that the original `Deferred` has failed, we should get a `CancelledError`. | ||
self.failureResultOf(wrapper_deferred, CancelledError) | ||
|
||
def test_raises_second_cancellation(self): | ||
"""Test that a second cancellation is instant when the `all` flag is not set.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = delay_cancellation(deferred, all=False) | ||
|
||
# Cancel the new `Deferred`, twice. | ||
wrapper_deferred.cancel() | ||
wrapper_deferred.cancel() | ||
self.failureResultOf(wrapper_deferred, CancelledError) | ||
self.assertFalse( | ||
deferred.called, "Original `Deferred` was unexpectedly cancelled" | ||
) | ||
|
||
# Now make the original `Deferred` fail. | ||
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed | ||
# in logs. | ||
deferred.errback(ValueError("abc")) | ||
self.assertIsNone(deferred.result, "`Failure` was not consumed") | ||
|
||
def test_propagates_cancelled_error(self): | ||
"""Test that a `CancelledError` from the original `Deferred` gets propagated.""" | ||
deferred: "Deferred[str]" = Deferred() | ||
wrapper_deferred = delay_cancellation(deferred, all=True) | ||
|
||
# Fail the original `Deferred` with a `CancelledError`. | ||
cancelled_error = CancelledError() | ||
deferred.errback(cancelled_error) | ||
|
||
# The new `Deferred` should fail with exactly the same `CancelledError`. | ||
self.assertTrue(wrapper_deferred.called) | ||
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) | ||
|
||
def test_preserves_logcontext_when_delaying_multiple_cancellations(self): | ||
"""Test that logging contexts are preserved when the `all` flag is set.""" | ||
blocking_d: "Deferred[None]" = Deferred() | ||
|
||
async def inner(): | ||
await make_deferred_yieldable(blocking_d) | ||
|
||
async def outer(): | ||
with LoggingContext("c") as c: | ||
try: | ||
await delay_cancellation(defer.ensureDeferred(inner()), all=True) | ||
self.fail("`CancelledError` was not raised") | ||
except CancelledError: | ||
self.assertEqual(c, current_context()) | ||
# Succeed with no error, unless the logging context is wrong. | ||
|
||
# Run and block inside `inner()`. | ||
d = defer.ensureDeferred(outer()) | ||
self.assertEqual(SENTINEL_CONTEXT, current_context()) | ||
|
||
d.cancel() | ||
d.cancel() | ||
|
||
# Now unblock. `outer()` will consume the `CancelledError` and check the | ||
# logging context. | ||
blocking_d.callback(None) | ||
self.successResultOf(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm led to wonder whether
all=False
is ever going to be useful? It seems to add complexity here so if we can avoid the need for it that would be good.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea, I was starting to wonder the same. Let's remove the option for now and always delay all cancellations.