diff --git a/pkg/util/sync/semaphore.go b/pkg/util/sync/semaphore.go index 2b5dd1b74d..38a9813607 100644 --- a/pkg/util/sync/semaphore.go +++ b/pkg/util/sync/semaphore.go @@ -32,6 +32,17 @@ func NewDynamicSemaphore(n int64) *DynamicSemaphore { func (s *DynamicSemaphore) SetSize(n int64) { s.mu.Lock() defer s.mu.Unlock() + + // If capacity increased, wake up waiters that can now acquire + if n > s.size { + for s.cur < n && s.waiters.Len() > 0 { + s.cur++ + next := s.waiters.Front() + s.waiters.Remove(next) + close(next.Value.(chan struct{})) + } + } + s.size = n } @@ -110,7 +121,7 @@ func (s *DynamicSemaphore) Release() { // And trigger it's chan before we release the lock close(next.Value.(chan struct{})) - // Note we _don't_ decrement inflight since the slot was yielded directly. + // Note we _don't_ decrement cur since the slot was yielded directly. } func (s *DynamicSemaphore) IsFull() bool { diff --git a/pkg/util/sync/semaphore_test.go b/pkg/util/sync/semaphore_test.go index e200fc8380..e3bfbe14e5 100644 --- a/pkg/util/sync/semaphore_test.go +++ b/pkg/util/sync/semaphore_test.go @@ -7,6 +7,7 @@ package sync import ( "context" + "fmt" "math/rand" "runtime" "sync" @@ -74,6 +75,56 @@ func checkAcquire(t *testing.T, sem *DynamicSemaphore, wantAcquire bool) { } } +func TestDynamicSemaphore_SetSize(t *testing.T) { + t.Parallel() + + t.Run("should wake waiter when setting larger size", func(t *testing.T) { + s := NewDynamicSemaphore(1) + require.NoError(t, s.Acquire(context.Background())) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + _ = s.Acquire(context.Background()) + fmt.Println("done") + wg.Done() + }() + go func() { + _ = s.Acquire(context.Background()) + fmt.Println("done") + wg.Done() + }() + + assert.Eventually(t, func() bool { + return s.Waiters() == 2 + }, 100*time.Millisecond, 10*time.Millisecond) + require.Equal(t, 2, s.Waiters()) + + // Increase size which should release waiters + s.SetSize(3) + wg.Wait() + assert.Equal(t, 0, s.Waiters()) + }) + + t.Run("should block acquires when setting smaller size", func(t *testing.T) { + s := NewDynamicSemaphore(3) + for i := 0; i < 3; i++ { + require.NoError(t, s.Acquire(context.Background())) + } + + s.SetSize(1) + for i := 0; i < 3; i++ { + s.Release() + } + + require.NoError(t, s.Acquire(context.Background())) + + // Should timeout while acquiring permit + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + require.Error(t, s.Acquire(ctx)) + }) +} + func TestDynamicSemaphore_Acquire(t *testing.T) { t.Parallel() @@ -203,12 +254,10 @@ func TestDynamicSemaphore_IsFull(t *testing.T) { } func TestDynamicSemaphore_Waiters(t *testing.T) { - overloadDuration := 100 * time.Millisecond s := NewDynamicSemaphore(1) err := s.Acquire(context.Background()) require.NoError(t, err) - // When go func() { _ = s.Acquire(context.Background()) }() @@ -216,8 +265,9 @@ func TestDynamicSemaphore_Waiters(t *testing.T) { _ = s.Acquire(context.Background()) }() - time.Sleep(overloadDuration) - assert.Equal(t, 2, s.Waiters()) + assert.Eventually(t, func() bool { + return s.Waiters() == 2 + }, 100*time.Millisecond, 10*time.Millisecond) s.Release() assert.Equal(t, 1, s.Waiters()) s.Release()