diff --git a/common/channel/shutdown_once.go b/common/channel/shutdown_once.go index 2cbbdcf59f2..8faaa2b10dd 100644 --- a/common/channel/shutdown_once.go +++ b/common/channel/shutdown_once.go @@ -25,6 +25,8 @@ package channel import ( + "context" + "sync" "sync/atomic" ) @@ -41,11 +43,18 @@ type ( IsShutdown() bool // Channel for shutdown notification Channel() <-chan struct{} + // PropagateShutdown propagates the shutdown signal to the input context, canceling it when shutdown is called. + // This is useful when creating contexts for background tasks that should be terminated when the service that + // started them is shutdown. + // The cancel function returned by this method should be called in the same manner as the cancel function + // returned by context.WithCancel. In essence, it should be called in a defer statement. + PropagateShutdown(ctx context.Context) (context.Context, context.CancelFunc) } ShutdownOnceImpl struct { status int32 channel chan struct{} + wg sync.WaitGroup } ) @@ -73,3 +82,20 @@ func (c *ShutdownOnceImpl) IsShutdown() bool { func (c *ShutdownOnceImpl) Channel() <-chan struct{} { return c.channel } + +// PropagateShutdown wraps the given context with a cancel function. We then spawn a goroutine which waits for the +// context to be canceled or for the shutdown channel to be closed. When shutdown is called, the cancel function is +// also called. +func (c *ShutdownOnceImpl) PropagateShutdown(ctx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + c.wg.Add(1) + go func() { + defer c.wg.Done() + select { + case <-c.channel: + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} diff --git a/common/channel/shutdown_once_test.go b/common/channel/shutdown_once_test.go new file mode 100644 index 00000000000..c6416123285 --- /dev/null +++ b/common/channel/shutdown_once_test.go @@ -0,0 +1,85 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package channel + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestShutdownOnceImpl_PropagateCancel(t *testing.T) { + t.Run("context is canceled when shutdown is called", func(t *testing.T) { + t.Parallel() + so := NewShutdownOnce() + ctx, cancel := so.PropagateShutdown(context.Background()) + defer cancel() + select { + case <-ctx.Done(): + t.Fatal("ctx should not be done") + default: + } + so.Shutdown() + <-ctx.Done() + err := ctx.Err() + assert.ErrorIs(t, err, context.Canceled) + require.True(t, so.IsShutdown()) + select { + case _, ok := <-so.Channel(): + assert.False(t, ok, "channel should be closed") + default: + t.Fatal("channel should be closed") + } + }) + t.Run("input context is canceled", func(t *testing.T) { + t.Parallel() + so := NewShutdownOnce() + inCtx, cancelInputCtx := context.WithCancel(context.Background()) + outCtx, _ := so.PropagateShutdown(inCtx) + cancelInputCtx() + so.wg.Wait() + require.False(t, so.IsShutdown(), "should not shutdown when input context is canceled") + select { + case <-outCtx.Done(): + default: + t.Fatal("output context should be canceled when input context is canceled") + } + }) + t.Run("output context is canceled", func(t *testing.T) { + t.Parallel() + so := NewShutdownOnce() + _, cancel := so.PropagateShutdown(context.Background()) + cancel() + so.wg.Wait() + require.False(t, so.IsShutdown(), "should not shutdown when spawned context is canceled") + select { + case <-so.Channel(): + t.Fatal("channel should not be closed") + default: + } + }) +}