Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add priority queue #2008

Merged
merged 1 commit into from
Feb 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ set(ZIG_STD_FILES
"os/windows/util.zig"
"os/zen.zig"
"pdb.zig"
"priority_queue.zig"
"rand/index.zig"
"rand/ziggurat.zig"
"segmented_list.zig"
Expand Down
4 changes: 3 additions & 1 deletion std/index.zig
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub const DynLib = @import("dynamic_library.zig").DynLib;
pub const HashMap = @import("hash_map.zig").HashMap;
pub const LinkedList = @import("linked_list.zig").LinkedList;
pub const Mutex = @import("mutex.zig").Mutex;
pub const PriorityQueue = @import("priority_queue.zig").PriorityQueue;
pub const StaticallyInitializedMutex = @import("statically_initialized_mutex.zig").StaticallyInitializedMutex;
pub const SegmentedList = @import("segmented_list.zig").SegmentedList;
pub const SpinLock = @import("spinlock.zig").SpinLock;
Expand Down Expand Up @@ -59,7 +60,7 @@ test "std" {
_ = @import("statically_initialized_mutex.zig");
_ = @import("segmented_list.zig");
_ = @import("spinlock.zig");

_ = @import("base64.zig");
_ = @import("build.zig");
_ = @import("c/index.zig");
Expand All @@ -85,6 +86,7 @@ test "std" {
_ = @import("net.zig");
_ = @import("os/index.zig");
_ = @import("pdb.zig");
_ = @import("priority_queue.zig");
_ = @import("rand/index.zig");
_ = @import("sort.zig");
_ = @import("testing.zig");
Expand Down
331 changes: 331 additions & 0 deletions std/priority_queue.zig
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
const std = @import("index.zig");
const Allocator = std.mem.Allocator;
const debug = std.debug;
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;

pub fn PriorityQueue(comptime T: type) type {
return struct {
const Self = @This();

items: []T,
len: usize,
allocator: *Allocator,
compareFn: fn (a: T, b: T) bool,

pub fn init(allocator: *Allocator, compareFn: fn (a: T, b: T) bool) Self {
return Self{
.items = []T{},
.len = 0,
.allocator = allocator,
.compareFn = compareFn,
};
}

pub fn deinit(self: Self) void {
self.allocator.free(self.items);
}

pub fn add(self: *Self, elem: T) !void {
try ensureCapacity(self, self.len + 1);

self.items[self.len] = elem;
var child_index = self.len;
while (child_index > 0) {
var parent_index = ((child_index - 1) >> 1);
const child = self.items[child_index];
const parent = self.items[parent_index];

if (!self.compareFn(child, parent)) break;

self.items[parent_index] = child;
self.items[child_index] = parent;
child_index = parent_index;
}
self.len += 1;
}

pub fn peek(self: *Self) ?T {
return if (self.len > 0) self.items[0] else null;
}

pub fn removeOrNull(self: *Self) ?T {
return if (self.len > 0) self.remove() else null;
}

pub fn remove(self: *Self) T {
const first = self.items[0];
const last = self.items[self.len - 1];
self.items[0] = last;
self.len -= 1;
siftDown(self);
return first;
}

pub fn count(self: Self) usize {
return self.len;
}

pub fn capacity(self: Self) usize {
return self.items.len;
}

fn siftDown(self: *Self) void {
var index: usize = 0;
const half = self.len >> 1;
while (true) {
var left_index = (index << 1) + 1;
var right_index = left_index + 1;
var left = if (left_index < self.len) self.items[left_index] else null;
var right = if (right_index < self.len) self.items[right_index] else null;

var smallest_index = index;
var smallest = self.items[index];

if (left) |e| {
if (self.compareFn(e, smallest)) {
smallest_index = left_index;
smallest = e;
}
}

if (right) |e| {
if (self.compareFn(e, smallest)) {
smallest_index = right_index;
smallest = e;
}
}

if (smallest_index == index) return;

self.items[smallest_index] = self.items[index];
self.items[index] = smallest;
index = smallest_index;

if (index >= half) return;
}
}

pub fn ensureCapacity(self: *Self, new_capacity: usize) !void {
var better_capacity = self.capacity();
if (better_capacity >= new_capacity) return;
while (true) {
better_capacity += better_capacity / 2 + 8;
if (better_capacity >= new_capacity) break;
}
self.items = try self.allocator.realloc(T, self.items, better_capacity);
}

pub fn resize(self: *Self, new_len: usize) !void {
try self.ensureCapacity(new_len);
self.len = new_len;
}

pub fn shrink(self: *Self, new_len: usize) void {
assert(new_len <= self.len);
self.len = new_len;
}

const Iterator = struct {
queue: *PriorityQueue(T),
count: usize,

fn next(it: *Iterator) ?T {
if (it.count > it.queue.len - 1) return null;
const out = it.count;
it.count += 1;
return it.queue.items[out];
}

fn reset(it: *Iterator) void {
it.count = 0;
}
};

pub fn iterator(self: *Self) Iterator {
return Iterator{
.queue = self,
.count = 0,
};
}

fn dump(self: *Self) void {
warn("{{ ");
warn("items: ");
for (self.items) |e, i| {
if (i >= self.len) break;
warn("{}, ", e);
}
warn("array: ");
for (self.items) |e, i| {
warn("{}, ", e);
}
warn("len: {} ", self.len);
warn("capacity: {}", self.capacity());
warn(" }}\n");
}
};
}

fn lessThan(a: u32, b: u32) bool {
return a < b;
}

fn greaterThan(a: u32, b: u32) bool {
return a > b;
}

const PQ = PriorityQueue(u32);

test "std.PriorityQueue: add and remove min heap" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();

try queue.add(54);
try queue.add(12);
try queue.add(7);
try queue.add(23);
try queue.add(25);
try queue.add(13);
expectEqual(u32(7), queue.remove());
expectEqual(u32(12), queue.remove());
expectEqual(u32(13), queue.remove());
expectEqual(u32(23), queue.remove());
expectEqual(u32(25), queue.remove());
expectEqual(u32(54), queue.remove());
}

test "std.PriorityQueue: add and remove same min heap" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();

try queue.add(1);
try queue.add(1);
try queue.add(2);
try queue.add(2);
try queue.add(1);
try queue.add(1);
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(2), queue.remove());
expectEqual(u32(2), queue.remove());
}

test "std.PriorityQueue: removeOrNull on empty" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();

expect(queue.removeOrNull() == null);
}

test "std.PriorityQueue: edge case 3 elements" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();

try queue.add(9);
try queue.add(3);
try queue.add(2);
expectEqual(u32(2), queue.remove());
expectEqual(u32(3), queue.remove());
expectEqual(u32(9), queue.remove());
}

test "std.PriorityQueue: peek" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();

expect(queue.peek() == null);
try queue.add(9);
try queue.add(3);
try queue.add(2);
expectEqual(u32(2), queue.peek().?);
expectEqual(u32(2), queue.peek().?);
}

test "std.PriorityQueue: sift up with odd indices" {
var queue = PQ.init(debug.global_allocator, lessThan);
defer queue.deinit();
const items = []u32{ 15, 7, 21, 14, 13, 22, 12, 6, 7, 25, 5, 24, 11, 16, 15, 24, 2, 1 };
for (items) |e| {
try queue.add(e);
}

expectEqual(u32(1), queue.remove());
expectEqual(u32(2), queue.remove());
expectEqual(u32(5), queue.remove());
expectEqual(u32(6), queue.remove());
expectEqual(u32(7), queue.remove());
expectEqual(u32(7), queue.remove());
expectEqual(u32(11), queue.remove());
expectEqual(u32(12), queue.remove());
expectEqual(u32(13), queue.remove());
expectEqual(u32(14), queue.remove());
expectEqual(u32(15), queue.remove());
expectEqual(u32(15), queue.remove());
expectEqual(u32(16), queue.remove());
expectEqual(u32(21), queue.remove());
expectEqual(u32(22), queue.remove());
expectEqual(u32(24), queue.remove());
expectEqual(u32(24), queue.remove());
expectEqual(u32(25), queue.remove());
}

test "std.PriorityQueue: add and remove max heap" {
var queue = PQ.init(debug.global_allocator, greaterThan);
defer queue.deinit();

try queue.add(54);
try queue.add(12);
try queue.add(7);
try queue.add(23);
try queue.add(25);
try queue.add(13);
expectEqual(u32(54), queue.remove());
expectEqual(u32(25), queue.remove());
expectEqual(u32(23), queue.remove());
expectEqual(u32(13), queue.remove());
expectEqual(u32(12), queue.remove());
expectEqual(u32(7), queue.remove());
}

test "std.PriorityQueue: add and remove same max heap" {
var queue = PQ.init(debug.global_allocator, greaterThan);
defer queue.deinit();

try queue.add(1);
try queue.add(1);
try queue.add(2);
try queue.add(2);
try queue.add(1);
try queue.add(1);
expectEqual(u32(2), queue.remove());
expectEqual(u32(2), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
expectEqual(u32(1), queue.remove());
}

test "std.PriorityQueue: iterator" {
var queue = PQ.init(debug.global_allocator, lessThan);
var map = std.AutoHashMap(u32, void).init(debug.global_allocator);
defer {
queue.deinit();
map.deinit();
}

const items = []u32{ 54, 12, 7, 23, 25, 13 };
for (items) |e| {
_ = try queue.add(e);
_ = try map.put(e, {});
}

var it = queue.iterator();
while (it.next()) |e| {
_ = map.remove(e);
}

expectEqual(usize(0), map.count());
}