// aoc-2024/lib/spice.zig

const std = @import("std");
// The overall design of Spice is as follows:
// - ThreadPool spawns threads which act as background workers.
// - A Worker, while executing, will share one piece of work (`shared_job`).
// - A Worker, while waiting, will look for jobs shared by other workers.
// See the usage sketch in the test at the end of this file.
pub const ThreadPoolConfig = struct {
/// The number of background workers. If `null`, a sensible default is chosen
/// based on your system (the number of CPU cores minus one).
background_worker_count: ?usize = null,
/// How often a worker is prompted (via its heartbeat) to share a piece of its queued work, in nanoseconds.
heartbeat_interval: usize = 100 * std.time.ns_per_us,
};
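
// A minimal lifecycle sketch (added as an illustration, not part of the library's
// own tests): start a pool with an explicit config, then shut it down again without
// submitting any work. Uses the test allocator so leaks would be caught.
test "thread pool lifecycle (illustrative)" {
    var pool = ThreadPool.init(std.testing.allocator);
    pool.start(.{
        .background_worker_count = 1,
        .heartbeat_interval = 50 * std.time.ns_per_us,
    });
    pool.deinit();
}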
pub const ThreadPool = struct {
allocator: std.mem.Allocator,
mutex: std.Thread.Mutex = .{},
/// List of all workers.
workers: std.ArrayListUnmanaged(*Worker) = .{},
/// List of all background workers.
background_threads: std.ArrayListUnmanaged(std.Thread) = .{},
/// The background thread which drives the heartbeats.
heartbeat_thread: ?std.Thread = null,
/// A pool for the JobExecuteState, to minimize allocations.
execute_state_pool: std.heap.MemoryPool(JobExecuteState),
/// This is used to signal that more jobs are now ready.
job_ready: std.Thread.Condition = .{},
/// This is used to wait for the background workers to be available initially.
workers_ready: std.Thread.Semaphore = .{},
/// This is set to true once we're trying to stop.
is_stopping: bool = false,
/// A counter which we increment whenever we share a job.
/// This is used to prioritize always picking the oldest shared job.
time: usize = 0,
heartbeat_interval: usize,
pub fn init(allocator: std.mem.Allocator) ThreadPool {
return ThreadPool{
.allocator = allocator,
.execute_state_pool = std.heap.MemoryPool(JobExecuteState).init(allocator),
.heartbeat_interval = undefined,
};
}
/// Starts the thread pool. This should only be invoked once.
pub fn start(self: *ThreadPool, config: ThreadPoolConfig) void {
const actual_count = config.background_worker_count orelse (std.Thread.getCpuCount() catch @panic("getCpuCount error")) - 1;
self.heartbeat_interval = config.heartbeat_interval;
self.background_threads.ensureUnusedCapacity(self.allocator, actual_count) catch @panic("OOM");
self.workers.ensureUnusedCapacity(self.allocator, actual_count) catch @panic("OOM");
for (0..actual_count) |_| {
const thread = std.Thread.spawn(.{}, backgroundWorker, .{self}) catch @panic("spawn error");
self.background_threads.append(self.allocator, thread) catch @panic("OOM");
}
self.heartbeat_thread = std.Thread.spawn(.{}, heartbeatWorker, .{self}) catch @panic("spawn error");
// Wait for all of them to be ready:
for (0..actual_count) |_| {
self.workers_ready.wait();
}
}
pub fn deinit(self: *ThreadPool) void {
// Tell all background workers to stop:
{
self.mutex.lock();
defer self.mutex.unlock();
self.is_stopping = true;
self.job_ready.broadcast();
}
// Wait for background workers to stop:
for (self.background_threads.items) |thread| {
thread.join();
}
if (self.heartbeat_thread) |thread| {
thread.join();
}
// Free up memory:
self.background_threads.deinit(self.allocator);
self.workers.deinit(self.allocator);
self.execute_state_pool.deinit();
self.* = undefined;
}
fn backgroundWorker(self: *ThreadPool) void {
var w = Worker{ .pool = self };
var first = true;
self.mutex.lock();
defer self.mutex.unlock();
self.workers.append(self.allocator, &w) catch @panic("OOM");
// We don't bother removing ourselves from the workers list on exit since
// this only happens when the whole thread pool is destroyed anyway.
while (true) {
if (self.is_stopping) break;
if (self._popReadyJob()) |job| {
// Release the lock while executing the job.
self.mutex.unlock();
defer self.mutex.lock();
w.executeJob(job);
continue; // Go straight to another attempt at finding more work.
}
if (first) {
// Register that we are ready.
self.workers_ready.post();
first = false;
}
self.job_ready.wait(&self.mutex);
}
}
fn heartbeatWorker(self: *ThreadPool) void {
// We try to make sure that each worker receives a heartbeat at the
// configured interval by going through the workers list one by one.
var i: usize = 0;
while (true) {
var to_sleep: u64 = self.heartbeat_interval;
{
self.mutex.lock();
defer self.mutex.unlock();
if (self.is_stopping) break;
const workers = self.workers.items;
if (workers.len > 0) {
i %= workers.len;
workers[i].heartbeat.store(true, .monotonic);
i += 1;
to_sleep /= workers.len;
}
}
std.time.sleep(to_sleep);
}
}
pub fn call(self: *ThreadPool, comptime T: type, func: anytype, arg: anytype) T {
// Create a one-off worker for the calling thread:
var worker = Worker{ .pool = self };
{
self.mutex.lock();
defer self.mutex.unlock();
self.workers.append(self.allocator, &worker) catch @panic("OOM");
}
defer {
self.mutex.lock();
defer self.mutex.unlock();
for (self.workers.items, 0..) |worker_ptr, idx| {
if (worker_ptr == &worker) {
_ = self.workers.swapRemove(idx);
break;
}
}
}
var t = worker.begin();
return t.call(T, func, arg);
}
/// The core logic of the heartbeat. Every executing worker invokes this periodically.
fn heartbeat(self: *ThreadPool, worker: *Worker) void {
@setCold(true);
self.mutex.lock();
defer self.mutex.unlock();
if (worker.shared_job == null) {
if (worker.job_head.shift()) |job| {
// Allocate an execute state for it:
const execute_state = self.execute_state_pool.create() catch @panic("OOM");
execute_state.* = .{
.result = undefined,
};
job.setExecuteState(execute_state);
worker.shared_job = job;
worker.job_time = self.time;
self.time += 1;
self.job_ready.signal(); // wake up one thread
}
}
worker.heartbeat.store(false, .monotonic);
}
/// Waits for (a shared) job to be completed.
/// This returns `false` if it turns out the job was not actually started.
fn waitForJob(self: *ThreadPool, worker: *Worker, job: *Job) bool {
const exec_state = job.getExecuteState();
{
self.mutex.lock();
defer self.mutex.unlock();
if (worker.shared_job == job) {
// This is the job we attempted to share, but nobody picked it up before we got here.
worker.shared_job = null;
self.execute_state_pool.destroy(exec_state);
return false;
}
// Help out by picking up more work if it's available.
while (!exec_state.done.isSet()) {
if (self._popReadyJob()) |other_job| {
self.mutex.unlock();
defer self.mutex.lock();
worker.executeJob(other_job);
} else {
break;
}
}
}
exec_state.done.wait();
return true;
}
/// Finds a job that's ready to be executed.
fn _popReadyJob(self: *ThreadPool) ?*Job {
var best_worker: ?*Worker = null;
for (self.workers.items) |other_worker| {
if (other_worker.shared_job) |_| {
if (best_worker) |best| {
if (other_worker.job_time < best.job_time) {
// Pick this one instead if it's older.
best_worker = other_worker;
}
} else {
best_worker = other_worker;
}
}
}
if (best_worker) |worker| {
defer worker.shared_job = null;
return worker.shared_job;
}
return null;
}
fn destroyExecuteState(self: *ThreadPool, exec_state: *JobExecuteState) void {
self.mutex.lock();
defer self.mutex.unlock();
self.execute_state_pool.destroy(exec_state);
}
};
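
// A minimal sketch of `ThreadPool.call` with a plain function (illustration only;
// `double` is a made-up helper for this test). The calling thread is wrapped in a
// temporary Worker and runs the function itself; no forking happens here.
test "ThreadPool.call runs a plain function (illustrative)" {
    const double = struct {
        fn run(t: *Task, x: u32) u32 {
            t.tick(); // give the scheduler a chance to act on a pending heartbeat
            return x * 2;
        }
    }.run;

    var pool = ThreadPool.init(std.testing.allocator);
    pool.start(.{ .background_worker_count = 1 });
    defer pool.deinit();

    try std.testing.expectEqual(@as(u32, 42), pool.call(u32, double, @as(u32, 21)));
}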
pub const Worker = struct {
pool: *ThreadPool,
job_head: Job = Job.head(),
/// A job (guaranteed to be in executing state) which other workers can pick up.
shared_job: ?*Job = null,
/// The time when the job was shared. Used for prioritizing which job to pick up.
job_time: usize = 0,
/// The heartbeat value. This is set to `true` to signal we should do a heartbeat action.
heartbeat: std.atomic.Value(bool) = std.atomic.Value(bool).init(true),
pub fn begin(self: *Worker) Task {
std.debug.assert(self.job_head.isTail());
return Task{
.worker = self,
.job_tail = &self.job_head,
};
}
fn executeJob(self: *Worker, job: *Job) void {
var t = self.begin();
job.handler.?(&t, job);
}
};
pub const Task = struct {
worker: *Worker,
job_tail: *Job,
pub inline fn tick(self: *Task) void {
if (self.worker.heartbeat.load(.monotonic)) {
self.worker.pool.heartbeat(self.worker);
}
}
pub inline fn call(self: *Task, comptime T: type, func: anytype, arg: anytype) T {
return callWithContext(
self.worker,
self.job_tail,
T,
func,
arg,
);
}
};
// The following function's signature is actually extremely critical. We take in all of
// the task state (worker, job_tail) as parameters. The reason for this is that Zig/LLVM
// is really good at passing parameters in registers, but struggles to do the same for
// "fields in structs". A fresh Task is then built from these values for the callee.
fn callWithContext(
worker: *Worker,
job_tail: *Job,
comptime T: type,
func: anytype,
arg: anytype,
) T {
var t = Task{
.worker = worker,
.job_tail = job_tail,
};
t.tick();
return @call(.always_inline, func, .{
&t,
arg,
});
}
pub const JobState = enum {
pending,
queued,
executing,
};
// A job represents something which _potentially_ could be executed on a different thread.
// Jobs form a doubly-linked list: you call `push` to append a job and `pop` to remove it,
// while `shift` takes the oldest job from the head.
const Job = struct {
handler: ?*const fn (t: *Task, job: *Job) void,
prev_or_null: ?*anyopaque,
next_or_state: ?*anyopaque,
// This struct gets placed on the stack in _every_ frame so we're very cautious
// about the size of it. There are three possible states, but we don't use a union(enum)
// since this would actually increase the size (see the illustrative tests after this struct).
//
// 1. pending: handler is null. prev_or_null/next_or_state are undefined.
// 2. queued: handler is set. prev_or_null is `prev`, next_or_state is `next`.
// 3. executing: handler is set. prev_or_null is null, next_or_state is `*JobExecuteState`.
/// Returns a new job which can be used for the head of a list.
fn head() Job {
return Job{
.handler = undefined,
.prev_or_null = null,
.next_or_state = null,
};
}
pub fn pending() Job {
return Job{
.handler = null,
.prev_or_null = undefined,
.next_or_state = undefined,
};
}
pub fn state(self: Job) JobState {
if (self.handler == null) return .pending;
if (self.prev_or_null != null) return .queued;
return .executing;
}
pub fn isTail(self: Job) bool {
return self.next_or_state == null;
}
fn getExecuteState(self: *Job) *JobExecuteState {
std.debug.assert(self.state() == .executing);
return @ptrCast(@alignCast(self.next_or_state));
}
pub fn setExecuteState(self: *Job, execute_state: *JobExecuteState) void {
std.debug.assert(self.state() == .executing);
self.next_or_state = execute_state;
}
/// Pushes the job onto the tail of the queue.
fn push(self: *Job, tail: **Job, handler: *const fn (task: *Task, job: *Job) void) void {
std.debug.assert(self.state() == .pending);
defer std.debug.assert(self.state() == .queued);
self.handler = handler;
tail.*.next_or_state = self; // tail.next = self
self.prev_or_null = tail.*; // self.prev = tail
self.next_or_state = null; // self.next = null
tail.* = self; // tail = self
}
fn pop(self: *Job, tail: **Job) void {
std.debug.assert(self.state() == .queued);
std.debug.assert(tail.* == self);
const prev: *Job = @ptrCast(@alignCast(self.prev_or_null));
prev.next_or_state = null; // prev.next = null
tail.* = @ptrCast(@alignCast(self.prev_or_null)); // tail = self.prev
self.* = undefined;
}
fn shift(self: *Job) ?*Job {
const job = @as(?*Job, @ptrCast(@alignCast(self.next_or_state))) orelse return null;
std.debug.assert(job.state() == .queued);
const next: ?*Job = @ptrCast(@alignCast(job.next_or_state));
// Now we have: self -> job -> next.
// If there is no `next` then it means that `tail` actually points to `job`.
// In this case we can't remove `job` since we're not able to also update the tail.
if (next == null) return null;
defer std.debug.assert(job.state() == .executing);
next.?.prev_or_null = self; // next.prev = self
self.next_or_state = next; // self.next = next
// Turn the job into "executing" state.
job.prev_or_null = null;
job.next_or_state = undefined;
return job;
}
};
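
// Illustrative tests (added here, not part of the library's own test suite) for the
// claims in the comments above: the struct stays at three machine words, and
// push/shift/pop move a job through the documented states. The `noop` handler is a
// stub and is never invoked.
test "Job is three pointers wide (illustrative)" {
    // Assumes the usual null-pointer optimization for `?*T`, so every field is
    // exactly one pointer in size.
    try std.testing.expectEqual(@as(usize, 3 * @sizeOf(*anyopaque)), @sizeOf(Job));
}

test "Job push/shift/pop state transitions (illustrative)" {
    const noop = struct {
        fn handler(t: *Task, job: *Job) void {
            _ = t;
            _ = job;
        }
    }.handler;

    var head = Job.head();
    var tail: *Job = &head;

    var a = Job.pending();
    var b = Job.pending();
    a.push(&tail, noop);
    b.push(&tail, noop);
    try std.testing.expectEqual(JobState.queued, a.state());

    // `shift` promotes the oldest job (`a`) to the executing state. It refuses
    // to take the only remaining job, since that would require moving the tail.
    const taken = head.shift().?;
    try std.testing.expect(taken == &a);
    try std.testing.expectEqual(JobState.executing, taken.state());
    try std.testing.expect(head.shift() == null);

    // The owning worker pops its own job from the tail.
    b.pop(&tail);
    try std.testing.expect(tail == &head);
    try std.testing.expect(head.isTail());
}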
const max_result_words = 4;
const JobExecuteState = struct {
done: std.Thread.ResetEvent = .{},
result: ResultType,
const ResultType = [max_result_words]u64;
fn resultPtr(self: *JobExecuteState, comptime T: type) *T {
if (@sizeOf(T) > @sizeOf(ResultType)) {
@compileError("value is too big to be returned by background thread");
}
const bytes = std.mem.sliceAsBytes(&self.result);
return std.mem.bytesAsValue(T, bytes);
}
};
pub fn Future(comptime Input: type, Output: type) type {
return struct {
const Self = @This();
job: Job,
input: Input,
pub inline fn init() Self {
return Self{ .job = Job.pending(), .input = undefined };
}
/// Schedules a piece of work which may be picked up and executed by another thread.
/// After this has been called you MUST call `join` or `tryJoin`.
pub inline fn fork(
self: *Self,
task: *Task,
comptime func: fn (task: *Task, input: Input) Output,
input: Input,
) void {
const handler = struct {
fn handler(t: *Task, job: *Job) void {
const fut: *Self = @fieldParentPtr("job", job);
const exec_state = job.getExecuteState();
const value = t.call(Output, func, fut.input);
exec_state.resultPtr(Output).* = value;
exec_state.done.set();
}
}.handler;
self.input = input;
self.job.push(&task.job_tail, handler);
}
/// Waits for the result of `fork`.
/// This is only safe to call if `fork` was _actually_ called.
/// Use `tryJoin` if you conditionally called it.
pub inline fn join(
self: *Self,
task: *Task,
) ?Output {
std.debug.assert(self.job.state() != .pending);
return self.tryJoin(task);
}
/// Waits for the result of `fork`.
/// This function is safe to call even if you didn't call `fork` at all.
pub inline fn tryJoin(
self: *Self,
task: *Task,
) ?Output {
switch (self.job.state()) {
.pending => return null,
.queued => {
self.job.pop(&task.job_tail);
return null;
},
.executing => return self.joinExecuting(task),
}
}
fn joinExecuting(self: *Self, task: *Task) ?Output {
@setCold(true);
const w = task.worker;
const pool = w.pool;
const exec_state = self.job.getExecuteState();
if (pool.waitForJob(w, &self.job)) {
const result = exec_state.resultPtr(Output).*;
pool.destroyExecuteState(exec_state);
return result;
}
return null;
}
};
}
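
// A usage sketch of the fork/join API (added as an illustration; the `Range` and
// `sum` helpers below are made up for this test and are not part of the library).
// It recursively splits a range, offers the left half to other workers via `fork`,
// computes the right half itself, and falls back to computing the left half inline
// when `join` reports that nobody picked the job up.
test "parallel sum over a range (illustrative)" {
    const helpers = struct {
        const Range = struct { lo: u64, hi: u64 };

        fn sum(t: *Task, r: Range) u64 {
            // Small ranges are summed serially.
            if (r.hi - r.lo <= 1024) {
                var acc: u64 = 0;
                var i = r.lo;
                while (i < r.hi) : (i += 1) acc += i;
                return acc;
            }
            const mid = r.lo + (r.hi - r.lo) / 2;
            var fut = Future(Range, u64).init();
            fut.fork(t, sum, .{ .lo = r.lo, .hi = mid });
            const right = t.call(u64, sum, Range{ .lo = mid, .hi = r.hi });
            const left = fut.join(t) orelse t.call(u64, sum, Range{ .lo = r.lo, .hi = mid });
            return left + right;
        }
    };

    var pool = ThreadPool.init(std.testing.allocator);
    pool.start(.{ .background_worker_count = 2 });
    defer pool.deinit();

    const total = pool.call(u64, helpers.sum, helpers.Range{ .lo = 0, .hi = 100_000 });
    try std.testing.expectEqual(@as(u64, 4_999_950_000), total);
}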