diff --git a/lib/std/Thread/Semaphore.zig b/lib/std/Thread/Semaphore.zig index be9234197b4d..36f68e42f5cf 100644 --- a/lib/std/Thread/Semaphore.zig +++ b/lib/std/Thread/Semaphore.zig @@ -41,8 +41,8 @@ pub fn tryWaitUntil(self: *Semaphore, deadline: Instant) error{TimedOut}!void { return self.impl.wait(deadline); } -pub fn post(self: *Semaphore) void { - return self.impl.post(); +pub fn post(self: *Semaphore, update: u31) void { + return self.impl.post(update); } const SerialSemaphoreImpl = struct { @@ -69,34 +69,40 @@ const SerialSemaphoreImpl = struct { return error.TimedOut; } - fn post(self: *Impl) void { - self.count += 1; + fn post(self: *Impl, update: u31) void { + var new_count = @as(u31, self.count); + if (@addWithOverflow(u31, new_count, update, &new_count)) { + unreachable; // post overflow + } + + const updated = @as(u32, new_count); + self.count = updated; } }; +/// Semaphore implementation adapted from: +/// https://softwareengineering.stackexchange.com/a/362533 const SemaphoreImpl = struct { - /// state holds the semaphore value in the upper 31 bits - /// and reserves the lowest bit to indicate if theres any waiting threads. - state: Atomic(u32) = Atomic(u32).init(0), + count: Atomic(i32) = Atomic(i32).init(0), + waiters: Atomic(u32) = Atomic(u32).init(0), - const IS_WAITING: u32 = 1 << 0; - const COUNT_SHIFT: std.math.Log2Int(u32) = @ctz(u32, IS_WAITING) + 1; + const LOCKED: i32 = 0; + const CONTENDED: i32 = -1; fn init(count: u31) Impl { - const state = @as(u32, count) << COUNT_SHIFT; - return .{ .state = Atomic(u32).init(state) }; + return .{ .count = Atomic(i32).init(count) }; } fn tryWait(self: *Impl) bool { - var state = self.state.load(.Monotonic); + var count = self.count.load(.Monotonic); while (true) { - if (state >> COUNT_SHIFT == 0) { + if (count <= LOCKED) { return false; } - state = self.state.tryCompareAndSwap( - state, - state - (1 << COUNT_SHIFT), + count = self.count.tryCompareAndSwap( + count, + count - 1, .Acquire, .Monotonic, ) orelse return true; @@ -104,102 +110,116 @@ const SemaphoreImpl = struct { } fn wait(self: *Impl, deadline: ?Instant) error{TimedOut}!void { + // Try to acquire a semaphore count wait-free for the uncontended fast path + const count = self.count.load(.Monotonic); + if (count > LOCKED) { + _ = self.count.tryCompareAndSwap( + count, + count - 1, + .Acquire, + .Monotonic, + ) orelse return; + } + + return self.waitSlow(deadline); + } + + fn waitSlow(self: *Impl, deadline: ?Instant) error{TimedOut}!void { @setCold(true); + // Try to acquire a semaphore count, spinning on the count a bit + // in hopes of a post() suppliying a count to avoid waiting below. + // Bails if there's already threads waiting on the count. var spin: u8 = 10; - var was_waiting = false; - var state = self.state.load(.Monotonic); - while (true) { - if (state >> COUNT_SHIFT != 0) { - state = self.state.tryCompareAndSwap( - state, - (state - (1 << COUNT_SHIFT)) | @boolToInt(was_waiting), + while (spin > 0) : (spin -= 1) { + std.atomic.spinLoopHint(); + + var count = self.count.load(.Monotonic); + if (count == LOCKED) continue; + if (count < LOCKED) break; + + while (count > LOCKED) { + count = self.count.tryCompareAndSwap( + count, + count - 1, .Acquire, .Monotonic, ) orelse return; - continue; } + } + + // Register that this thread is waiting on the semaphore + var waiters = self.waiters.fetchAdd(1, .Monotonic); + assert(waiters != std.math.maxInt(u32)); - if (state & IS_WAITING == 0) { - // Spin a little bit if there's no waiting threads. - // This creates low-latency wake-up if a post() is fast enough. - if (spin > 0) { - spin -= 1; - std.atomic.spinLoopHint(); - state = self.state.load(.Monotonic); + // Once we're done waiting, no matter the condition, unregister ourselves from waiting on the semaphore + defer { + waiters = self.waiters.fetchSub(1, .Monotonic); + assert(waiters != 0); + } + + while (true) { + var count = self.count.load(.Monotonic); + while (true) { + // Grab a semaphore count if there is one + if (count > LOCKED) { + count = self.count.tryCompareAndSwap( + count, + count - 1, + .Acquire, + .Monotonic, + ) orelse return; continue; } - // Make sure that IS_WAITING is set before waiting on the state. - // This only updates the state so no Acquire barrier is needed. - if (self.state.tryCompareAndSwap( - state, - state | IS_WAITING, + // Ensure that the semaphore count is CONTENDED before sleeping. + if (count == CONTENDED) break; + count = self.count.tryCompareAndSwap( + count, + CONTENDED, .Monotonic, .Monotonic, - )) |updated| { - state = updated; - continue; - } + ) orelse break; } try Futex.wait( - &self.state, - state | IS_WAITING, + &self.count, + CONTENDED, deadline, ); - - spin = 10; - was_waiting = true; - state = self.state.load(.Monotonic); } } - fn post(self: *Impl) void { - const state = blk: { - // If on x86, only post the count. We unset the IS_WAITING bit in the slow path. - // A `lock xadd` here is often faster than the `lock cmpxchg` loop down below. - if (comptime arch.isX86()) { - break :blk self.state.fetchAdd(1 << COUNT_SHIFT, .Release); - } + fn post(self: *Impl, update: u31) void { + var waiters: u32 = 0; + var count = self.count.load(.Monotonic); - var state = self.state.load(.Monotonic); - while (true) { - state = self.state.tryCompareAndSwap( - state, - (state + (1 << COUNT_SHIFT)) & ~IS_WAITING, - .Release, - .Monotonic, - ) orelse break :blk state; + // Bump the semaphore count while checking if theres any threads to wake up. + while (true) { + var new_count = count + @boolToInt(count < LOCKED); + if (@addWithOverflow(i32, new_count, @as(i32, update), &new_count)) { + unreachable; // semaphore count overflow } - }; - - assert((state >> COUNT_SHIFT) > 0); - if (state & IS_WAITING != 0) { - self.postSlow(); - } - } - - fn postSlow(self: *Impl) void { - @setCold(true); - // The fast path only increments the semaphore value without - // removing the IS_WAITING bit so that must be done separately. - // - // Multiple post() threads could observe the IS_WAITING bit set - // so this can race and should have one thread perform the wake-up. - if (comptime arch.isX86()) { - const was_waiting = self.state.bitUnset(@ctz(u32, IS_WAITING), .Monotonic); - if (was_waiting == 0) { - return; - } + waiters = self.waiters.load(.Monotonic); + count = self.count.tryCompareAndSwap( + count, + new_count, + .Release, + .Monotonic, + ) orelse break; } - const num_waiters = 1; - Futex.wake(&self.state, num_waiters); + // Wake up some semaphore waiters without touching self.count + // as the Semaphore could have consumed the update above and deallocated itself. + if ((count < LOCKED) or (waiters > 0)) { + Futex.wake(&self.count, @as(u32, update)); + } } }; + + test "Semaphore" { return error.TODO; }