Skip to content

Commit 126180c

Browse files
committed
Sema: implement @errSetCast for error unions
Closes ziglang#17343
1 parent 376242e commit 126180c

File tree

3 files changed

+62
-20
lines changed

3 files changed

+62
-20
lines changed

doc/langref.html.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -8413,7 +8413,7 @@ test "main" {
84138413
{#header_open|@errSetCast#}
84148414
<pre>{#syntax#}@errSetCast(value: anytype) anytype{#endsyntax#}</pre>
84158415
<p>
8416-
Converts an error value from one error set to another error set. The return type is the
8416+
Converts an error set or error union value from one error set to another error set. The return type is the
84178417
inferred result type. Attempting to convert an error which is not in the destination error
84188418
set results in safety-protected {#link|Undefined Behavior#}.
84198419
</p>

src/Sema.zig

+45-19
Original file line numberDiff line numberDiff line change
@@ -21753,11 +21753,25 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
2175321753
const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
2175421754
const src = LazySrcLoc.nodeOffset(extra.node);
2175521755
const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
21756-
const dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_eu_opt, "@errSetCast");
21756+
const base_dest_ty = try sema.resolveDestType(block, src, extra.lhs, .remove_opt, "@errSetCast");
2175721757
const operand = try sema.resolveInst(extra.rhs);
21758-
const operand_ty = sema.typeOf(operand);
21759-
try sema.checkErrorSetType(block, src, dest_ty);
21760-
try sema.checkErrorSetType(block, operand_src, operand_ty);
21758+
const base_operand_ty = sema.typeOf(operand);
21759+
const dest_tag = base_dest_ty.zigTypeTag(mod);
21760+
const operand_tag = base_operand_ty.zigTypeTag(mod);
21761+
if (dest_tag != operand_tag) {
21762+
return sema.fail(block, src, "expected source and destination types to match, found '{s}' and '{s}'", .{
21763+
@tagName(operand_tag), @tagName(dest_tag),
21764+
});
21765+
} else if (dest_tag != .ErrorSet and dest_tag != .ErrorUnion) {
21766+
return sema.fail(block, src, "expected error set or error union type, found '{s}'", .{@tagName(dest_tag)});
21767+
}
21768+
const dest_ty, const operand_ty = if (dest_tag == .ErrorUnion) .{
21769+
base_dest_ty.errorUnionSet(mod),
21770+
base_operand_ty.errorUnionSet(mod),
21771+
} else .{
21772+
base_dest_ty,
21773+
base_operand_ty,
21774+
};
2176121775

2176221776
// operand must be defined since it can be an invalid error value
2176321777
const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand);
@@ -21804,8 +21818,15 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
2180421818
}
2180521819

2180621820
if (maybe_operand_val) |val| {
21807-
if (!dest_ty.isAnyError(mod)) {
21808-
const error_name = mod.intern_pool.indexToKey(val.toIntern()).err.name;
21821+
if (!dest_ty.isAnyError(mod)) check: {
21822+
const operand_val = mod.intern_pool.indexToKey(val.toIntern());
21823+
var error_name: InternPool.NullTerminatedString = undefined;
21824+
if (dest_tag == .ErrorUnion) {
21825+
if (operand_val.error_union.val != .err_name) break :check;
21826+
error_name = operand_val.error_union.val.err_name;
21827+
} else {
21828+
error_name = operand_val.err.name;
21829+
}
2180921830
if (!Type.errorSetHasFieldIp(ip, dest_ty.toIntern(), error_name)) {
2181021831
const msg = msg: {
2181121832
const msg = try sema.errMsg(
@@ -21822,16 +21843,29 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
2182221843
}
2182321844
}
2182421845

21825-
return Air.internedToRef((try mod.getCoerced(val, dest_ty)).toIntern());
21846+
return Air.internedToRef((try mod.getCoerced(val, base_dest_ty)).toIntern());
2182621847
}
2182721848

2182821849
try sema.requireRuntimeBlock(block, src, operand_src);
2182921850
if (block.wantSafety() and !dest_ty.isAnyError(mod) and sema.mod.backendSupportsFeature(.error_set_has_value)) {
21830-
const err_int_inst = try block.addBitCast(Type.err_int, operand);
21831-
const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
21832-
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
21851+
if (dest_tag == .ErrorUnion) {
21852+
const err_code = try sema.analyzeErrUnionCode(block, operand_src, operand);
21853+
const err_int = try block.addBitCast(Type.err_int, err_code);
21854+
const zero_u16 = Air.internedToRef(try mod.intern(.{
21855+
.int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } },
21856+
}));
21857+
21858+
const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code);
21859+
const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16);
21860+
const ok = try block.addBinOp(.bit_or, has_value, is_zero);
21861+
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
21862+
} else {
21863+
const err_int_inst = try block.addBitCast(Type.err_int, operand);
21864+
const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst);
21865+
try sema.addSafetyCheck(block, src, ok, .invalid_error_code);
21866+
}
2183321867
}
21834-
return block.addBitCast(dest_ty, operand);
21868+
return block.addBitCast(base_dest_ty, operand);
2183521869
}
2183621870

2183721871
fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
@@ -22916,14 +22950,6 @@ fn checkIntOrVectorAllowComptime(
2291622950
}
2291722951
}
2291822952

22919-
fn checkErrorSetType(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) CompileError!void {
22920-
const mod = sema.mod;
22921-
switch (ty.zigTypeTag(mod)) {
22922-
.ErrorSet => return,
22923-
else => return sema.fail(block, src, "expected error set type, found '{}'", .{ty.fmt(mod)}),
22924-
}
22925-
}
22926-
2292722953
const SimdBinOp = struct {
2292822954
len: ?usize,
2292922955
/// Coerced to `result_ty`.

test/behavior/error.zig

+16
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,22 @@ fn testExplicitErrorSetCast(set1: Set1) !void {
235235
try expect(y == error.A);
236236
}
237237

238+
test "@errSetCast on error unions" {
239+
const S = struct {
240+
fn doTheTest() !void {
241+
const casted: error{Bad}!i32 = @errSetCast(retErrUnion());
242+
try expect((try casted) == 1234);
243+
}
244+
245+
fn retErrUnion() anyerror!i32 {
246+
return 1234;
247+
}
248+
};
249+
250+
try S.doTheTest();
251+
try comptime S.doTheTest();
252+
}
253+
238254
test "comptime test error for empty error set" {
239255
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
240256

0 commit comments

Comments
 (0)