From 0409700585fb88d16d9ccd729009892825e07c85 Mon Sep 17 00:00:00 2001 From: Jonathan Halterman Date: Tue, 4 Feb 2025 17:58:59 -0800 Subject: [PATCH] Improve DynamicSemaphore --- pkg/util/adaptivelimiter/adaptivelimiter.go | 6 +-- pkg/util/sync/semaphore.go | 52 ++++++++++-------- pkg/util/sync/semaphore_test.go | 58 +++++++++++++++++++-- 3 files changed, 87 insertions(+), 29 deletions(-) diff --git a/pkg/util/adaptivelimiter/adaptivelimiter.go b/pkg/util/adaptivelimiter/adaptivelimiter.go index 33fa8fe7e4e..e067af026e1 100644 --- a/pkg/util/adaptivelimiter/adaptivelimiter.go +++ b/pkg/util/adaptivelimiter/adaptivelimiter.go @@ -181,7 +181,7 @@ func (l *adaptiveLimiter) AcquirePermit(ctx context.Context) (Permit, error) { return &recordingPermit{ limiter: l, startTime: time.Now(), - currentInflight: l.semaphore.Inflight(), + currentInflight: l.semaphore.Used(), }, nil } @@ -192,7 +192,7 @@ func (l *adaptiveLimiter) TryAcquirePermit() (Permit, bool) { return &recordingPermit{ limiter: l, startTime: time.Now(), - currentInflight: l.semaphore.Inflight(), + currentInflight: l.semaphore.Used(), }, true } @@ -207,7 +207,7 @@ func (l *adaptiveLimiter) Limit() int { } func (l *adaptiveLimiter) Inflight() int { - return l.semaphore.Inflight() + return l.semaphore.Used() } func (l *adaptiveLimiter) Blocked() int { diff --git a/pkg/util/sync/semaphore.go b/pkg/util/sync/semaphore.go index 2b5dd1b74dc..bfe34cf961c 100644 --- a/pkg/util/sync/semaphore.go +++ b/pkg/util/sync/semaphore.go @@ -14,7 +14,7 @@ import ( type DynamicSemaphore struct { mu sync.Mutex size int64 - cur int64 + used int64 waiters list.List } @@ -32,6 +32,17 @@ func NewDynamicSemaphore(n int64) *DynamicSemaphore { func (s *DynamicSemaphore) SetSize(n int64) { s.mu.Lock() defer s.mu.Unlock() + + // If capacity increased, wake up waiters that can now acquire + if n > s.size { + for s.used < n && s.waiters.Len() > 0 { + s.used++ + next := s.waiters.Front() + s.waiters.Remove(next) + close(next.Value.(chan struct{})) + } + } + s.size = n } @@ -42,8 +53,8 @@ func (s *DynamicSemaphore) SetSize(n int64) { // If ctx is already done, Acquire may still succeed without blocking. func (s *DynamicSemaphore) Acquire(ctx context.Context) error { s.mu.Lock() - if s.cur < s.size { - s.cur++ + if s.used < s.size { + s.used++ s.mu.Unlock() return nil } @@ -79,8 +90,8 @@ func (s *DynamicSemaphore) TryAcquire() bool { s.mu.Lock() defer s.mu.Unlock() - if s.cur < s.size { - s.cur++ + if s.used < s.size { + s.used++ return true } return false @@ -92,42 +103,39 @@ func (s *DynamicSemaphore) Release() { s.mu.Lock() defer s.mu.Unlock() - if s.cur < 1 { - panic("semaphore: bad release") + if s.used < 1 { + panic("DynamicSemaphore: unexpected release") } - next := s.waiters.Front() + waiter := s.waiters.Front() - // If there are no waiters, or if we recently resized and cur is too high, just decrement and we're done - if next == nil || s.cur > s.size { - s.cur-- + // If there are no waiters or if we recently resized and used is too high, just decrement and we're done + if waiter == nil || s.used > s.size { + s.used-- return } - // Need to yield our slot to the next waiter. - // Remove them from the list - s.waiters.Remove(next) - - // And trigger it's chan before we release the lock - close(next.Value.(chan struct{})) - // Note we _don't_ decrement inflight since the slot was yielded directly. + // Remove and release the waiter + s.waiters.Remove(waiter) + close(waiter.Value.(chan struct{})) } func (s *DynamicSemaphore) IsFull() bool { s.mu.Lock() defer s.mu.Unlock() - return s.cur >= s.size + return s.used >= s.size } -// Waiters returns how many callers are blocked waiting for the semaphore. +// Waiters returns how many callers are blocked waiting for permits. func (s *DynamicSemaphore) Waiters() int { s.mu.Lock() defer s.mu.Unlock() return s.waiters.Len() } -func (s *DynamicSemaphore) Inflight() int { +// Used returns how many slots are currently in use. +func (s *DynamicSemaphore) Used() int { s.mu.Lock() defer s.mu.Unlock() - return int(s.cur) + return int(s.used) } diff --git a/pkg/util/sync/semaphore_test.go b/pkg/util/sync/semaphore_test.go index e200fc8380e..e3bfbe14e5b 100644 --- a/pkg/util/sync/semaphore_test.go +++ b/pkg/util/sync/semaphore_test.go @@ -7,6 +7,7 @@ package sync import ( "context" + "fmt" "math/rand" "runtime" "sync" @@ -74,6 +75,56 @@ func checkAcquire(t *testing.T, sem *DynamicSemaphore, wantAcquire bool) { } } +func TestDynamicSemaphore_SetSize(t *testing.T) { + t.Parallel() + + t.Run("should wake waiter when setting larger size", func(t *testing.T) { + s := NewDynamicSemaphore(1) + require.NoError(t, s.Acquire(context.Background())) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + _ = s.Acquire(context.Background()) + fmt.Println("done") + wg.Done() + }() + go func() { + _ = s.Acquire(context.Background()) + fmt.Println("done") + wg.Done() + }() + + assert.Eventually(t, func() bool { + return s.Waiters() == 2 + }, 100*time.Millisecond, 10*time.Millisecond) + require.Equal(t, 2, s.Waiters()) + + // Increase size which should release waiters + s.SetSize(3) + wg.Wait() + assert.Equal(t, 0, s.Waiters()) + }) + + t.Run("should block acquires when setting smaller size", func(t *testing.T) { + s := NewDynamicSemaphore(3) + for i := 0; i < 3; i++ { + require.NoError(t, s.Acquire(context.Background())) + } + + s.SetSize(1) + for i := 0; i < 3; i++ { + s.Release() + } + + require.NoError(t, s.Acquire(context.Background())) + + // Should timeout while acquiring permit + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) + require.Error(t, s.Acquire(ctx)) + }) +} + func TestDynamicSemaphore_Acquire(t *testing.T) { t.Parallel() @@ -203,12 +254,10 @@ func TestDynamicSemaphore_IsFull(t *testing.T) { } func TestDynamicSemaphore_Waiters(t *testing.T) { - overloadDuration := 100 * time.Millisecond s := NewDynamicSemaphore(1) err := s.Acquire(context.Background()) require.NoError(t, err) - // When go func() { _ = s.Acquire(context.Background()) }() @@ -216,8 +265,9 @@ func TestDynamicSemaphore_Waiters(t *testing.T) { _ = s.Acquire(context.Background()) }() - time.Sleep(overloadDuration) - assert.Equal(t, 2, s.Waiters()) + assert.Eventually(t, func() bool { + return s.Waiters() == 2 + }, 100*time.Millisecond, 10*time.Millisecond) s.Release() assert.Equal(t, 1, s.Waiters()) s.Release()