Skip to content

Commit

Permalink
std.Thread.Once: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kprotty committed Oct 5, 2021
1 parent c574814 commit d2f94f9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 42 deletions.
110 changes: 70 additions & 40 deletions lib/std/Thread/Once.zig
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
const std = @import("../std.zig");
const target = std.Target.current;
const assert = std.debug.assert;
const testing = std.testing;
const os = std.os;
const c = std.c;

Expand Down Expand Up @@ -29,7 +30,7 @@ else
const SerialImpl = struct {
was_called: bool = false,

fn call(self: *Impl, comptime f: fn () void) void {
fn call(self: *Impl, f: fn () void) void {
if (self.was_called) return;
f();
self.was_called = true;
Expand Down Expand Up @@ -88,53 +89,82 @@ const FutexImpl = extern struct {
const WAITING = 2;
const CALLED = 3;

fn call(self: *Impl, comptime f: fn () void) void {
fn call(self: *Impl, f: fn () void) void {
if (self.state.load(.Acquire) == CALLED) return;
self.callSlow(f);
}

noinline fn callSlow(self: *Impl, comptime f: fn () void) void {
noinline fn callSlow(self: *Impl, f: fn () void) void {
@setCold(true);

var spin = SpinWait{};
// Try to transition from UNCALLED -> CALLING in order to call f().
// Once called, transition to CALLED and wake up waiting threads if WAITING.
var state = self.state.load(.Acquire);
while (true) {
state = switch (state) {
// Transition from UNCALLED -> CALLING in order to invoke f().
// Once done, transition to CALLED and wake any waiting threads if WAITING.
UNCALLED => self.state.tryCompareAndSwap(state, CALLING, .Acquire, .Acquire) orelse {
f();
return switch (self.state.swap(CALLED, .Release)) {
UNCALLED => unreachable, // CALLED function while not CALLING
CALLING => {},
WAITING => Futex.wake(&self.state, std.math.maxInt(u32)),
CALLED => unreachable, // CALLED function when already CALLED
};
},
CALLING => blk: {
// Spin a bit in hopes that the CALLING thread will be finished soon.
if (spin.yield()) {
break :blk self.state.load(.Acquire);
}

// Transition to WAITING to ensure that the CALLING thread wakes us up when it's done.
break :blk self.state.tryCompareAndSwap(
CALLING,
WAITING,
.Acquire,
.Acquire,
) orelse WAITING;
},
WAITING => blk: {
// Wait on the state for the CALLING thread to finish and wake us up.
Futex.wait(&self.state, WAITING, null) catch unreachable;
break :blk self.state.load(.Acquire);
},
CALLED => {
// The function has finally been called
return;
},
if (state == UNCALLED) {
state = self.state.compareAndSwap(UNCALLED, CALLING, .Acquire, .Acquire) orelse {
f();
return switch (self.state.swap(CALLED, .Release)) {
UNCALLED => unreachable, // invoked function while not CALLING
CALLING => {},
WAITING => Futex.wake(&self.state, std.math.maxInt(u32)),
CALLED => unreachable, // invoked function when already CALLED
else => unreachable, // invalid Once state
};
};
}

// Spin a bit on the Once in hopes the thread calling f() finishes quickly.
// Only spin if there are no other threads waiting on the Once (!= WAITING).
var spin = SpinWait{};
while (state == CALLING and spin.yield()) {
state = self.state.load(.Acquire);
}

// If we've spun for too long and the thread calling f() hasn't finished yet,
// then update the state to WAITING to signify that threads are sleeping.
if (state == CALLING) {
state = self.state.compareAndSwap(
CALLING,
WAITING,
.Acquire,
.Acquire,
) orelse WAITING;
}

// Wait on the state until the thread calling f() finishes and wakes us up.
while (state == WAITING) {
Futex.wait(&self.state, WAITING, null) catch unreachable;
state = self.state.load(.Acquire);
}

// It was called.. right?
assert(state == CALLED);
}
};

test "Once" {
const num_threads = 4;
const Context = struct {
var once = Once{};
var number: i32 = 0;

fn inc() void {
number += 1;
}

fn call() void {
once.call(inc);
}
};

if (std.builtin.single_threaded) {
Context.call();
Context.call();
} else {
var threads: [num_threads]std.Thread = undefined;
for (threads) |*t| t.* = try std.Thread.spawn(.{}, Context.call, .{});
for (threads) |t| t.join();
}

try testing.expectEqual(Context.number, 1);
}
4 changes: 2 additions & 2 deletions lib/std/c/darwin.zig
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ pub const DISPATCH_TIME_NOW = @as(dispatch_time_t, 0);
pub const DISPATCH_TIME_FOREVER = ~@as(dispatch_time_t, 0);
pub extern "c" fn dispatch_time(when: dispatch_time_t, delta: i64) dispatch_time_t;

const dispatch_once_t = usize;
const dispatch_function_t = fn (?*c_void) callconv(.C) void;
pub const dispatch_once_t = usize;
pub const dispatch_function_t = fn (?*c_void) callconv(.C) void;
pub extern fn dispatch_once_f(
predicate: *dispatch_once_t,
context: ?*c_void,
Expand Down

0 comments on commit d2f94f9

Please sign in to comment.