stage2: implement union coercion to its own tag

* AIR: add `get_union_tag` instruction
   - implement in LLVM backend
 * Sema: implement == and != for union and enum literal
   - Also implement coercion from union to its own tag type
 * Value: implement hashing for union values

The motivating example is this snippet:

    comptime assert(@typeInfo(T) == .Float);

This was the next blocker for stage2 building compiler-rt.
Now the next blocker is a switch at compile-time on an integer.
This commit is contained in:
Andrew Kelley 2021-09-27 23:11:00 -07:00
parent c2a7542df5
commit 09e1f37cb6
11 changed files with 162 additions and 33 deletions

View File

@ -290,6 +290,9 @@ pub const Inst = struct {
/// Result type is always void.
/// Uses the `bin_op` field. LHS is union pointer, RHS is new tag value.
set_union_tag,
/// Given a tagged union value, get its tag value.
/// Uses the `ty_op` field.
get_union_tag,
/// Given a slice value, return the length.
/// Result type is always usize.
/// Uses the `ty_op` field.
@ -630,6 +633,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
.array_to_slice,
.float_to_int,
.int_to_float,
.get_union_tag,
=> return air.getRefType(datas[inst].ty_op.ty),
.loop,

View File

@ -297,6 +297,7 @@ fn analyzeInst(
.array_to_slice,
.float_to_int,
.int_to_float,
.get_union_tag,
=> {
const o = inst_datas[inst].ty_op;
return trackOperands(a, new_set, inst, main_tomb, .{ o.operand, .none, .none });

View File

@ -1349,7 +1349,13 @@ fn zirUnionDecl(
errdefer new_decl_arena.deinit();
const union_obj = try new_decl_arena.allocator.create(Module.Union);
const union_ty = try Type.Tag.@"union".create(&new_decl_arena.allocator, union_obj);
const type_tag: Type.Tag = if (small.has_tag_type or small.auto_enum_tag) .union_tagged else .@"union";
const union_payload = try new_decl_arena.allocator.create(Type.Payload.Union);
union_payload.* = .{
.base = .{ .tag = type_tag },
.data = union_obj,
};
const union_ty = Type.initPayload(&union_payload.base);
const union_val = try Value.Tag.ty.create(&new_decl_arena.allocator, union_ty);
const type_name = try sema.createTypeName(block, small.name_strategy);
const new_decl = try sema.mod.createAnonymousDeclNamed(&block.base, .{
@ -6477,10 +6483,11 @@ fn zirCmpEq(
const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
}
if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
(rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
{
return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
if (lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) {
return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op);
}
if (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union) {
return sema.analyzeCmpUnionTag(block, lhs, lhs_src, rhs, rhs_src, op);
}
if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
const runtime_src: LazySrcLoc = src: {
@ -6521,6 +6528,28 @@ fn zirCmpEq(
return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true);
}
/// Analyzes `op` applied to a union operand (`un`) and an enum-literal operand (`tag`).
/// Fails with a compile error reported at `un_src` when the union type has no tag type.
/// Otherwise both operands are coerced to the union's tag type (the tag first, then
/// the union) and the comparison is delegated to `cmpSelf`, with the union on the
/// left-hand side.
fn analyzeCmpUnionTag(
    sema: *Sema,
    block: *Scope.Block,
    un: Air.Inst.Ref,
    un_src: LazySrcLoc,
    tag: Air.Inst.Ref,
    tag_src: LazySrcLoc,
    op: std.math.CompareOperator,
) CompileError!Air.Inst.Ref {
    const tag_ty = if (sema.typeOf(un).unionTagType()) |ty| ty else {
        // TODO note at declaration site that says "union foo is not tagged"
        return sema.mod.fail(&block.base, un_src, "comparison of union and enum literal is only valid for tagged union types", .{});
    };

    // Once both sides have the union's tag type, this is an ordinary comparison of
    // two values of the same type.
    const rhs_as_tag = try sema.coerce(block, tag_ty, tag, tag_src);
    const lhs_as_tag = try sema.coerce(block, tag_ty, un, un_src);
    return sema.cmpSelf(block, lhs_as_tag, rhs_as_tag, op, un_src, tag_src);
}
/// Only called for non-equality operators. See also `zirCmpEq`.
fn zirCmp(
sema: *Sema,
@ -6567,10 +6596,21 @@ fn analyzeCmp(
@tagName(op), resolved_type,
});
}
const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
return sema.cmpSelf(block, casted_lhs, casted_rhs, op, lhs_src, rhs_src);
}
fn cmpSelf(
sema: *Sema,
block: *Scope.Block,
casted_lhs: Air.Inst.Ref,
casted_rhs: Air.Inst.Ref,
op: std.math.CompareOperator,
lhs_src: LazySrcLoc,
rhs_src: LazySrcLoc,
) CompileError!Air.Inst.Ref {
const resolved_type = sema.typeOf(casted_lhs);
const runtime_src: LazySrcLoc = src: {
if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type);
@ -9919,9 +9959,9 @@ fn coerce(
}
}
},
.Enum => {
// enum literal to enum
if (inst_ty.zigTypeTag() == .EnumLiteral) {
.Enum => switch (inst_ty.zigTypeTag()) {
.EnumLiteral => {
// enum literal to enum
const val = try sema.resolveConstValue(block, inst_src, inst);
const bytes = val.castTag(.enum_literal).?.data;
const resolved_dest_type = try sema.resolveTypeFields(block, inst_src, dest_type);
@ -9948,7 +9988,15 @@ fn coerce(
resolved_dest_type,
try Value.Tag.enum_field_index.create(arena, @intCast(u32, field_index)),
);
}
},
.Union => blk: {
// union to its own tag type
const union_tag_ty = inst_ty.unionTagType() orelse break :blk;
if (union_tag_ty.eql(dest_type)) {
return sema.unionToTag(block, dest_type, inst, inst_src);
}
},
else => {},
},
.ErrorUnion => {
// T to E!T or E to E!T
@ -10802,6 +10850,20 @@ fn wrapErrorUnion(
}
}
/// Produces the tag of the union operand `un` as a value of type `dest_type`.
/// When the operand's value can be resolved during analysis, the result is a constant
/// built from `Value.unionTag`. Otherwise a runtime block is required and a
/// `get_union_tag` instruction is added to `block`.
fn unionToTag(
    sema: *Sema,
    block: *Scope.Block,
    dest_type: Type,
    un: Air.Inst.Ref,
    un_src: LazySrcLoc,
) !Air.Inst.Ref {
    const maybe_val = try sema.resolveMaybeUndefVal(block, un_src, un);
    const known_val = maybe_val orelse {
        // Value not available during analysis: emit the runtime instruction.
        try sema.requireRuntimeBlock(block, un_src);
        return block.addTyOp(.get_union_tag, dest_type, un);
    };
    return sema.addConstant(dest_type, known_val.unionTag());
}
fn resolvePeerTypes(
sema: *Sema,
block: *Scope.Block,

View File

@ -890,6 +890,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
.memcpy => try self.airMemcpy(inst),
.memset => try self.airMemset(inst),
.set_union_tag => try self.airSetUnionTag(inst),
.get_union_tag => try self.airGetUnionTag(inst),
.atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
.atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@ -1552,6 +1553,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
}
/// Handles the `get_union_tag` AIR instruction (a `ty_op` instruction).
/// No architecture is implemented yet: an unused result is recorded as `.dead`,
/// and any used result reports a TODO failure naming the target architecture.
fn airGetUnionTag(self: *Self, inst: Air.Inst.Index) !void {
    const operand = self.air.instructions.items(.data)[inst].ty_op.operand;
    const is_unused = self.liveness.isUnused(inst);
    const result: MCValue = if (is_unused) .dead else switch (arch) {
        // Per-architecture lowerings are expected to be added as prongs here.
        else => return self.fail("TODO implement airGetUnionTag for {}", .{self.target.cpu.arch}),
    };
    return self.finishAir(inst, result, .{ operand, .none, .none });
}
fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
if (!self.liveness.operandDies(inst, op_index))
return false;

View File

@ -956,6 +956,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
.memset => try airMemset(f, inst),
.memcpy => try airMemcpy(f, inst),
.set_union_tag => try airSetUnionTag(f, inst),
.get_union_tag => try airGetUnionTag(f, inst),
.int_to_float,
.float_to_int,
@ -2096,6 +2097,22 @@ fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
return CValue.none;
}
/// Lowers the `get_union_tag` AIR instruction (a `ty_op` instruction) for the C backend.
/// Returns `CValue.none` when liveness reports the result as unused. Otherwise returns
/// a newly allocated const local of the instruction's result type, initialized from a
/// `get_union_tag(<operand>)` call expression in the generated C.
fn airGetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
    if (f.liveness.isUnused(inst))
        return CValue.none;

    const inst_ty = f.air.typeOfIndex(inst);
    const local = try f.allocLocal(inst_ty, .Const);
    const ty_op = f.air.instructions.items(.data)[inst].ty_op;
    const writer = f.object.writer();
    const operand = try f.resolveInst(ty_op.operand);

    // The returned local must actually receive the value, so the call has to be
    // written as the local's initializer rather than as a bare expression.
    // NOTE(review): assumes `allocLocal` has just written the local's declaration
    // without an initializer or terminating `;` -- confirm against `allocLocal`.
    // NOTE(review): `get_union_tag` is not defined anywhere in the visible source;
    // presumably a placeholder for a helper supplied to the generated C -- verify.
    try writer.writeAll(" = get_union_tag(");
    try f.writeCValue(writer, operand);
    try writer.writeAll(");\n");
    return local;
}
fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 {
return switch (order) {
.Unordered => "memory_order_relaxed",

View File

@ -1304,6 +1304,7 @@ pub const FuncGen = struct {
.memset => try self.airMemset(inst),
.memcpy => try self.airMemcpy(inst),
.set_union_tag => try self.airSetUnionTag(inst),
.get_union_tag => try self.airGetUnionTag(inst),
.atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
.atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@ -2557,6 +2558,18 @@ pub const FuncGen = struct {
return null;
}
/// Lowers the `get_union_tag` AIR instruction (a `ty_op` instruction) to LLVM IR.
/// Returns null when liveness reports the result as unused; otherwise builds an
/// extract-value of index 1 from the resolved union operand.
fn airGetUnionTag(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
    if (self.liveness.isUnused(inst)) return null;

    const operand = self.air.instructions.items(.data)[inst].ty_op.operand;
    const union_ty = self.air.typeOf(operand);
    const union_llvm_val = try self.resolveInst(operand);

    _ = union_ty; // TODO handle when onlyTagHasCodegenBits() == true and other union forms
    return self.builder.buildExtractValue(union_llvm_val, 1, "");
}
fn fieldPtr(
self: *FuncGen,
inst: Air.Inst.Index,

View File

@ -179,6 +179,7 @@ const Writer = struct {
.array_to_slice,
.int_to_float,
.float_to_int,
.get_union_tag,
=> try w.writeTyOp(s, inst),
.block,

View File

@ -2487,6 +2487,12 @@ pub const Type = extern union {
};
}
/// Returns the type of the union field selected by `enum_tag`.
/// Asserts that `ty` carries a `Payload.Union` and that `enum_tag` maps to a field
/// index of the union's tag type.
pub fn unionFieldType(ty: Type, enum_tag: Value) Type {
    const payload = ty.cast(Payload.Union).?;
    const field_index = payload.data.tag_ty.enumTagFieldIndex(enum_tag).?;
    const field_list = payload.data.fields.values();
    return field_list[field_index].ty;
}
/// Asserts that the type is an error union.
pub fn errorUnionPayload(self: Type) Type {
return switch (self.tag()) {
@ -3801,6 +3807,8 @@ pub const Type = extern union {
};
};
pub const @"bool" = initTag(.bool);
pub fn ptr(arena: *Allocator, d: Payload.Pointer.Data) !Type {
assert(d.host_size == 0 or d.bit_offset < d.host_size * 8);

View File

@ -1275,7 +1275,12 @@ pub const Value = extern union {
}
},
.Union => {
@panic("TODO implement hashing union values");
const union_obj = val.castTag(.@"union").?.data;
if (ty.unionTagType()) |tag_ty| {
union_obj.tag.hash(tag_ty, hasher);
}
const active_field_ty = ty.unionFieldType(union_obj.tag);
union_obj.val.hash(active_field_ty, hasher);
},
.Fn => {
@panic("TODO implement hashing function values");
@ -1431,6 +1436,14 @@ pub const Value = extern union {
}
}
/// Returns the tag stored in a union value.
/// An `undef` value is returned unchanged. Any value tag other than `undef` or
/// `union` is asserted unreachable.
pub fn unionTag(val: Value) Value {
    return switch (val.tag()) {
        .undef => val,
        .@"union" => val.castTag(.@"union").?.data.tag,
        else => unreachable,
    };
}
/// Returns a pointer to the element value at the index.
pub fn elemPtr(self: Value, allocator: *Allocator, index: usize) !Value {
if (self.castTag(.elem_ptr)) |elem_ptr| {

View File

@ -14,3 +14,21 @@ test "basic unions" {
foo = Foo{ .float = 12.34 };
try expect(foo.float == 12.34);
}
// Overwrites a union through a pointer from helper functions, then checks that the
// field that was just written reads back with the stored value.
test "init union with runtime value" {
// Start from `undefined`; each helper call below assigns the whole union.
var foo: Foo = undefined;
setFloat(&foo, 12.34);
try expect(foo.float == 12.34);
// Switch the union over to its `int` field and read that back.
setInt(&foo, 42);
try expect(foo.int == 42);
}
/// Overwrites `foo.*` with a `Foo` whose `float` field is initialized to `x`.
fn setFloat(foo: *Foo, x: f64) void {
foo.* = Foo{ .float = x };
}
/// Overwrites `foo.*` with a `Foo` whose `int` field is initialized to `x`.
fn setInt(foo: *Foo, x: i32) void {
foo.* = Foo{ .int = x };
}

View File

@ -49,24 +49,6 @@ test "comptime union field access" {
}
}
// Overwrites a union through a pointer from helper functions, then checks that the
// field that was just written reads back with the stored value.
test "init union with runtime value" {
// Start from `undefined`; each helper call below assigns the whole union.
var foo: Foo = undefined;
setFloat(&foo, 12.34);
try expect(foo.float == 12.34);
// Switch the union over to its `int` field and read that back.
setInt(&foo, 42);
try expect(foo.int == 42);
}
/// Overwrites `foo.*` with a `Foo` whose `float` field is initialized to `x`.
fn setFloat(foo: *Foo, x: f64) void {
foo.* = Foo{ .float = x };
}
/// Overwrites `foo.*` with a `Foo` whose `int` field is initialized to `x`.
fn setInt(foo: *Foo, x: i32) void {
foo.* = Foo{ .int = x };
}
const FooExtern = extern union {
float: f64,
int: i32,
@ -185,12 +167,13 @@ test "union field access gives the enum values" {
}
test "cast union to tag type of union" {
try testCastUnionToTag(TheUnion{ .B = 1234 });
comptime try testCastUnionToTag(TheUnion{ .B = 1234 });
try testCastUnionToTag();
comptime try testCastUnionToTag();
}
fn testCastUnionToTag(x: TheUnion) !void {
try expect(@as(TheTag, x) == TheTag.B);
fn testCastUnionToTag() !void {
var u = TheUnion{ .B = 1234 };
try expect(@as(TheTag, u) == TheTag.B);
}
test "cast tag type of union to union" {