Skip to content

Commit

Permalink
traitor: Add associated type support
Browse files Browse the repository at this point in the history
GATs currently do not work in zig 0.11 due to
ziglang/zig#6709.
  • Loading branch information
greytdepression committed Aug 28, 2023
1 parent ee27e9e commit cc85290
Showing 1 changed file with 255 additions and 4 deletions.
259 changes: 255 additions & 4 deletions src/traitor.zig
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,31 @@ const ErrorCode = enum(u8) {
MissingDeclaration = 6,
MissingFunction = 7,
MissingField = 8,
MetaDataHasIncorrectType = 9,
TraitMetaDataHasIncorrectType = 9,
IllegalUseOfTraitorInternalDecl = 10,
TraitMissingGenericTypeDeclaration = 11,
TraitGenericTypeNotAType = 12,
TraitGenericTypeStructNotAuto = 13,

// https://github.com/ziglang/zig/issues/6709
TraitGenericTypeStructHasDecls = 14,
};

// Constants
const eval_branch_quota = 1_000_000;

// Associated Types Helpers
pub const GenericSelf = struct {};

const associated_type_decl_identifier = "__traitor_internal_associated_type_decl_name";
pub fn AssociatedType(comptime decl_name: []const u8) type {
return struct {
const __traitor_internal_associated_type_decl_name = decl_name;
};
}

pub fn checkTrait(comptime Trait: type, comptime T: type) void {
@setEvalBranchQuota(eval_branch_quota);

//--------------------------------------------------
// Error Messaging Setup
Expand Down Expand Up @@ -270,7 +291,7 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void {
if (isCoercibleToString(@TypeOf(value))) {
trait_name = value;
} else {
printError("The type of the trait's '{s}' declaration must be compatible with '[]const u8', found '{s}' instead.", &error_writer, .MetaDataHasIncorrectType, .{
printError("The type of the trait's '{s}' declaration must be compatible with '[]const u8', found '{s}' instead.", &error_writer, .TraitMetaDataHasIncorrectType, .{
meta_trait_name,
@typeName(@TypeOf(value)),
});
Expand All @@ -286,6 +307,13 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void {
// Check if T satisfies the trait boundary
//--------------------------------------------------

var cache = BufferedTypeTypeHashMap{};

var ctx = Context{
.writer = &error_writer,
.cache = &cache,
};

error_writer.print(
\\The type '{s}' does not satisfy the trait bounds of trait '{s}' due to the following errors:
\\
Expand All @@ -301,7 +329,8 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void {
continue;
}

const trait_decl_type = @TypeOf(@field(Trait, trait_decl_name));
const trait_decl_type_raw = @TypeOf(@field(Trait, trait_decl_name));
const trait_decl_type = SubstitutedType(ctx, trait_decl_type_raw, Trait, T, trait_decl_type_raw);

if (@hasDecl(T, trait_decl_name)) {
const t_decl_type = @TypeOf(@field(T, trait_decl_name));
Expand All @@ -321,7 +350,8 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void {
// Fields
inline for (trait_struct_info.fields) |trait_field| {
const trait_field_name = trait_field.name;
const trait_field_type = trait_field.type;
const trait_field_type_raw = trait_field.type;
const trait_field_type = SubstitutedType(ctx, trait_field_type_raw, Trait, T, trait_field_type_raw);

if (@hasField(T, trait_field_name)) {
const dummy: T = undefined;
Expand Down Expand Up @@ -423,3 +453,224 @@ fn isCoercibleToString(comptime T: type) bool {

return false;
}

const BufferedTypeTypeHashMap = struct {
const N = 1024;

keys: [N]?type = [_]?type{null} ** N,
values: [N]?type = [_]?type{null} ** N,
num_occupied: usize = 0,
num_failed_to_write: usize = 0,
num_cache_hits: usize = 0,
num_cache_misses: usize = 0,

fn hashString(s: []const u8) usize {
return @intCast(std.hash.Wyhash.hash(0, s) & (N - 1));
}

fn getValue(comptime self: *@This(), comptime key: type) ?type {
const hash_index = hashString(@typeName(key));

const key_at_index = self.keys[hash_index];

if (key_at_index != null and key_at_index.? == key) {
self.num_cache_hits += 1;
return self.values[hash_index];
}

self.num_cache_misses += 1;

return null;
}

fn setKV(comptime self: *@This(), comptime key: type, comptime value: type) void {
const hash_index = hashString(@typeName(key));

const key_at_index = self.keys[hash_index];

if (key_at_index == null) {
self.keys[hash_index] = key;
self.values[hash_index] = value;
self.num_occupied += 1;
} else {
self.num_failed_to_write += 1;
}
}
};

const Context = struct {
writer: *Writer,
cache: *BufferedTypeTypeHashMap,
};

fn SubstitutedType(comptime ctx: Context, comptime pattern: type, comptime Trait: type, comptime T: type, comptime original_pattern: type) type {
if (pattern == GenericSelf)
return T;

// Immediately return basic types without going through the cache
const pattern_info = @typeInfo(pattern);
switch (pattern_info) {
.Pointer, .Array, .Struct, .Optional, .ErrorUnion, .ErrorSet, .Enum, .Union, .Fn, .Vector => {},
else => return pattern,
}

// Try to find it in the cache
const cached_type = ctx.cache.getValue(pattern);

if (cached_type) |ctype| {
@compileLog("we had", pattern, "cached! it's", ctype);
return ctype;
}

// Do actual substitution
switch (pattern_info) {
.Struct => |strct| {

// Deal with AssociatedType(S)
if (@hasDecl(pattern, associated_type_decl_identifier)) {
// Get the name of the declaration that specifies the associated type
const associated_type_decl_name = @field(pattern, associated_type_decl_identifier);

// Check that associated_type_decl_name is in fact a string
if (@TypeOf(associated_type_decl_name) != []const u8) {
printError("Illegal use of `{s}` declaration in trait.", ctx.writer, .IllegalUseOfTraitorInternalDecl, .{associated_type_decl_identifier});

return pattern;
}

// Check that
// 1. This declaration is part of the trait specification
if (!@hasDecl(Trait, associated_type_decl_name)) {
printError("Expected declaration of generic type '{s}' in trait.", ctx.writer, .TraitMissingGenericTypeDeclaration, .{associated_type_decl_name});

return pattern;
}

// 2. The declaration is of type type
if (@TypeOf(@field(Trait, associated_type_decl_name)) != type) {
printError("Expected declaration of generic type '{s}' to be of type 'type', got '{s}' instead.", ctx.writer, .TraitGenericTypeNotAType, .{
associated_type_decl_name,
@typeName(@TypeOf(@field(Trait, associated_type_decl_name))),
});

return pattern;
}

// 3. Check that T has the declaration
if (!@hasDecl(T, associated_type_decl_name) or @TypeOf(@field(T, associated_type_decl_name)) != type) {
// Do not throw an error, as it will already be reported elsewhere.
return pattern;
}

// Everything is fine. Return the associated type.
const resolved_type = @field(T, associated_type_decl_name);
ctx.cache.setKV(pattern, resolved_type);
return resolved_type;
}

// This is the general case

// Check that we don't do weird stuff. If the layout is not Auto,
// we don't allow generic shenanigans.
const may_be_generic = strct.layout == .Auto;

var is_generic = false;

var modified_info = strct;

var field_buffer: [strct.fields.len]Type.StructField = undefined;

for (strct.fields, 0..) |field, i| {
const subst_type = SubstitutedType(ctx, field.type, Trait, T, original_pattern);

if (field.type != subst_type) {
is_generic = true;
}

field_buffer[i] = .{
.name = field.name,
.type = subst_type,
.default_value = null,
.is_comptime = field.is_comptime,
.alignment = @alignOf(subst_type),
};
}

modified_info.fields = &field_buffer;

if (!may_be_generic and is_generic) {
printError("Structs making use of GATs must have automatic layout. Found issue in '{s}'.", ctx.writer, .TraitGenericTypeStructNotAuto, .{
@typeName(original_pattern),
});

return pattern;
}

// https://github.com/ziglang/zig/issues/6709
if (is_generic and strct.decls.len != 0) {
printError("Structs making use of GATs must not have declarations. Found issue in '{s}'.", ctx.writer, .TraitGenericTypeStructHasDecls, .{
@typeName(original_pattern),
});

return pattern;
}

const out_type = @Type(.{ .Struct = modified_info });
ctx.cache.setKV(pattern, out_type);
return out_type;
},
.Fn => |func| {
var modified_info = func;
var param_buffer: [func.params.len]Type.Fn.Param = undefined;

for (func.params, 0..) |param, i| {
param_buffer[i] = param;

if (param_buffer[i].type != null) {
param_buffer[i].type = SubstitutedType(ctx, param.type.?, Trait, T, original_pattern);
}
}

modified_info.return_type = if (func.return_type == null) null else SubstitutedType(ctx, func.return_type.?, Trait, T, original_pattern);
modified_info.params = &param_buffer;

const out_type = @Type(.{ .Fn = modified_info });
ctx.cache.setKV(pattern, out_type);
return out_type;
},
.Pointer => |ptr| {
var modified_info = ptr;
modified_info.child = SubstitutedType(ctx, ptr.child, Trait, T, original_pattern);

return @Type(.{ .Pointer = modified_info });
},
.Array => |arr| {
var modified_info = arr;
modified_info.child = SubstitutedType(ctx, arr.child, Trait, T, original_pattern);

return @Type(.{ .Array = modified_info });
},
.Optional => |opt| {
var modified_info = opt;
modified_info.child = SubstitutedType(ctx, opt.child, Trait, T, original_pattern);

return @Type(.{ .Optional = modified_info });
},
.ErrorUnion => |erruni| {
var modified_info = erruni;
modified_info.payload = SubstitutedType(ctx, erruni.payload, Trait, T, original_pattern);

return @Type(.{ .ErrorUnion = modified_info });
},
.Vector => |vec| {
var modified_info = vec;
modified_info.child = SubstitutedType(ctx, vec.child, Trait, T, original_pattern);

return @Type(.{ .Vector = modified_info });
},
.Enum, .Union => {}, // TODO: we first need caching
else => unreachable,
}

return pattern;
}

0 comments on commit cc85290

Please sign in to comment.