diff --git a/src/traitor.zig b/src/traitor.zig index d987777..24a70ae 100644 --- a/src/traitor.zig +++ b/src/traitor.zig @@ -188,10 +188,31 @@ const ErrorCode = enum(u8) { MissingDeclaration = 6, MissingFunction = 7, MissingField = 8, - MetaDataHasIncorrectType = 9, + TraitMetaDataHasIncorrectType = 9, + IllegalUseOfTraitorInternalDecl = 10, + TraitMissingGenericTypeDeclaration = 11, + TraitGenericTypeNotAType = 12, + TraitGenericTypeStructNotAuto = 13, + + // https://github.com/ziglang/zig/issues/6709 + TraitGenericTypeStructHasDecls = 14, }; +// Constants +const eval_branch_quota = 1_000_000; + +// Associated Types Helpers +pub const GenericSelf = struct {}; + +const associated_type_decl_identifier = "__traitor_internal_associated_type_decl_name"; +pub fn AssociatedType(comptime decl_name: []const u8) type { + return struct { + const __traitor_internal_associated_type_decl_name = decl_name; + }; +} + pub fn checkTrait(comptime Trait: type, comptime T: type) void { + @setEvalBranchQuota(eval_branch_quota); //-------------------------------------------------- // Error Messaging Setup @@ -270,7 +291,7 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void { if (isCoercibleToString(@TypeOf(value))) { trait_name = value; } else { - printError("The type of the trait's '{s}' declaration must be compatible with '[]const u8', found '{s}' instead.", &error_writer, .MetaDataHasIncorrectType, .{ + printError("The type of the trait's '{s}' declaration must be compatible with '[]const u8', found '{s}' instead.", &error_writer, .TraitMetaDataHasIncorrectType, .{ meta_trait_name, @typeName(@TypeOf(value)), }); @@ -286,6 +307,13 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void { // Check if T satisfies the trait boundary //-------------------------------------------------- + var cache = BufferedTypeTypeHashMap{}; + + var ctx = Context{ + .writer = &error_writer, + .cache = &cache, + }; + error_writer.print( \\The type '{s}' does not satisfy the trait bounds of trait '{s}' due to the following errors: \\ @@ -301,7 +329,8 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void { continue; } - const trait_decl_type = @TypeOf(@field(Trait, trait_decl_name)); + const trait_decl_type_raw = @TypeOf(@field(Trait, trait_decl_name)); + const trait_decl_type = SubstitutedType(ctx, trait_decl_type_raw, Trait, T, trait_decl_type_raw); if (@hasDecl(T, trait_decl_name)) { const t_decl_type = @TypeOf(@field(T, trait_decl_name)); @@ -321,7 +350,8 @@ pub fn checkTrait(comptime Trait: type, comptime T: type) void { // Fields inline for (trait_struct_info.fields) |trait_field| { const trait_field_name = trait_field.name; - const trait_field_type = trait_field.type; + const trait_field_type_raw = trait_field.type; + const trait_field_type = SubstitutedType(ctx, trait_field_type_raw, Trait, T, trait_field_type_raw); if (@hasField(T, trait_field_name)) { const dummy: T = undefined; @@ -423,3 +453,224 @@ fn isCoercibleToString(comptime T: type) bool { return false; } + +const BufferedTypeTypeHashMap = struct { + const N = 1024; + + keys: [N]?type = [_]?type{null} ** N, + values: [N]?type = [_]?type{null} ** N, + num_occupied: usize = 0, + num_failed_to_write: usize = 0, + num_cache_hits: usize = 0, + num_cache_misses: usize = 0, + + fn hashString(s: []const u8) usize { + return @intCast(std.hash.Wyhash.hash(0, s) & (N - 1)); + } + + fn getValue(comptime self: *@This(), comptime key: type) ?type { + const hash_index = hashString(@typeName(key)); + + const key_at_index = self.keys[hash_index]; + + if (key_at_index != null and key_at_index.? == key) { + self.num_cache_hits += 1; + return self.values[hash_index]; + } + + self.num_cache_misses += 1; + + return null; + } + + fn setKV(comptime self: *@This(), comptime key: type, comptime value: type) void { + const hash_index = hashString(@typeName(key)); + + const key_at_index = self.keys[hash_index]; + + if (key_at_index == null) { + self.keys[hash_index] = key; + self.values[hash_index] = value; + self.num_occupied += 1; + } else { + self.num_failed_to_write += 1; + } + } +}; + +const Context = struct { + writer: *Writer, + cache: *BufferedTypeTypeHashMap, +}; + +fn SubstitutedType(comptime ctx: Context, comptime pattern: type, comptime Trait: type, comptime T: type, comptime original_pattern: type) type { + if (pattern == GenericSelf) + return T; + + // Immediately return basic types without going through the cache + const pattern_info = @typeInfo(pattern); + switch (pattern_info) { + .Pointer, .Array, .Struct, .Optional, .ErrorUnion, .ErrorSet, .Enum, .Union, .Fn, .Vector => {}, + else => return pattern, + } + + // Try to find it in the cache + const cached_type = ctx.cache.getValue(pattern); + + if (cached_type) |ctype| { + @compileLog("we had", pattern, "cached! it's", ctype); + return ctype; + } + + // Do actual substitution + switch (pattern_info) { + .Struct => |strct| { + + // Deal with AssociatedType(S) + if (@hasDecl(pattern, associated_type_decl_identifier)) { + // Get the name of the declaration that specifies the associated type + const associated_type_decl_name = @field(pattern, associated_type_decl_identifier); + + // Check that associated_type_decl_name is in fact a string + if (@TypeOf(associated_type_decl_name) != []const u8) { + printError("Illegal use of `{s}` declaration in trait.", ctx.writer, .IllegalUseOfTraitorInternalDecl, .{associated_type_decl_identifier}); + + return pattern; + } + + // Check that + // 1. This declaration is part of the trait specification + if (!@hasDecl(Trait, associated_type_decl_name)) { + printError("Expected declaration of generic type '{s}' in trait.", ctx.writer, .TraitMissingGenericTypeDeclaration, .{associated_type_decl_name}); + + return pattern; + } + + // 2. The declaration is of type type + if (@TypeOf(@field(Trait, associated_type_decl_name)) != type) { + printError("Expected declaration of generic type '{s}' to be of type 'type', got '{s}' instead.", ctx.writer, .TraitGenericTypeNotAType, .{ + associated_type_decl_name, + @typeName(@TypeOf(@field(Trait, associated_type_decl_name))), + }); + + return pattern; + } + + // 3. Check that T has the declaration + if (!@hasDecl(T, associated_type_decl_name) or @TypeOf(@field(T, associated_type_decl_name)) != type) { + // Do not throw an error, as it will already be reported elsewhere. + return pattern; + } + + // Everything is fine. Return the associated type. + const resolved_type = @field(T, associated_type_decl_name); + ctx.cache.setKV(pattern, resolved_type); + return resolved_type; + } + + // This is the general case + + // Check that we don't do weird stuff. If the layout is not Auto, + // we don't allow generic shenanigans. + const may_be_generic = strct.layout == .Auto; + + var is_generic = false; + + var modified_info = strct; + + var field_buffer: [strct.fields.len]Type.StructField = undefined; + + for (strct.fields, 0..) |field, i| { + const subst_type = SubstitutedType(ctx, field.type, Trait, T, original_pattern); + + if (field.type != subst_type) { + is_generic = true; + } + + field_buffer[i] = .{ + .name = field.name, + .type = subst_type, + .default_value = null, + .is_comptime = field.is_comptime, + .alignment = @alignOf(subst_type), + }; + } + + modified_info.fields = &field_buffer; + + if (!may_be_generic and is_generic) { + printError("Structs making use of GATs must have automatic layout. Found issue in '{s}'.", ctx.writer, .TraitGenericTypeStructNotAuto, .{ + @typeName(original_pattern), + }); + + return pattern; + } + + // https://github.com/ziglang/zig/issues/6709 + if (is_generic and strct.decls.len != 0) { + printError("Structs making use of GATs must not have declarations. Found issue in '{s}'.", ctx.writer, .TraitGenericTypeStructHasDecls, .{ + @typeName(original_pattern), + }); + + return pattern; + } + + const out_type = @Type(.{ .Struct = modified_info }); + ctx.cache.setKV(pattern, out_type); + return out_type; + }, + .Fn => |func| { + var modified_info = func; + var param_buffer: [func.params.len]Type.Fn.Param = undefined; + + for (func.params, 0..) |param, i| { + param_buffer[i] = param; + + if (param_buffer[i].type != null) { + param_buffer[i].type = SubstitutedType(ctx, param.type.?, Trait, T, original_pattern); + } + } + + modified_info.return_type = if (func.return_type == null) null else SubstitutedType(ctx, func.return_type.?, Trait, T, original_pattern); + modified_info.params = ¶m_buffer; + + const out_type = @Type(.{ .Fn = modified_info }); + ctx.cache.setKV(pattern, out_type); + return out_type; + }, + .Pointer => |ptr| { + var modified_info = ptr; + modified_info.child = SubstitutedType(ctx, ptr.child, Trait, T, original_pattern); + + return @Type(.{ .Pointer = modified_info }); + }, + .Array => |arr| { + var modified_info = arr; + modified_info.child = SubstitutedType(ctx, arr.child, Trait, T, original_pattern); + + return @Type(.{ .Array = modified_info }); + }, + .Optional => |opt| { + var modified_info = opt; + modified_info.child = SubstitutedType(ctx, opt.child, Trait, T, original_pattern); + + return @Type(.{ .Optional = modified_info }); + }, + .ErrorUnion => |erruni| { + var modified_info = erruni; + modified_info.payload = SubstitutedType(ctx, erruni.payload, Trait, T, original_pattern); + + return @Type(.{ .ErrorUnion = modified_info }); + }, + .Vector => |vec| { + var modified_info = vec; + modified_info.child = SubstitutedType(ctx, vec.child, Trait, T, original_pattern); + + return @Type(.{ .Vector = modified_info }); + }, + .Enum, .Union => {}, // TODO: we first need caching + else => unreachable, + } + + return pattern; +}