Merge pull request #17352 from kcbanner/extern_union_comptime_memory

sema: Support reinterpreting extern/packed unions at comptime via field access
This commit is contained in:
Andrew Kelley 2023-10-03 11:20:08 -07:00 committed by GitHub
commit 87d09edf2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 376 additions and 109 deletions

View File

@ -6607,6 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in
return field_ty.abiAlignment(mod);
}
/// Returns the index of the active field, given the current tag value
pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 {
const ip = &mod.intern_pool;
if (enum_tag.toIntern() == .none) return null;

View File

@ -27258,7 +27258,7 @@ fn unionFieldVal(
return sema.failWithOwnedErrorMsg(block, msg);
}
},
.Packed, .Extern => {
.Packed, .Extern => |layout| {
if (tag_matches) {
return Air.internedToRef(un.val);
} else {
@ -27267,7 +27267,7 @@ fn unionFieldVal(
else
union_ty.unionFieldType(un.tag.toValue(), mod).?;
if (try sema.bitCastVal(block, src, un.val.toValue(), old_ty, field_ty, 0)) |new_val| {
if (try sema.bitCastUnionFieldVal(block, src, un.val.toValue(), old_ty, field_ty, layout)) |new_val| {
return Air.internedToRef(new_val.toIntern());
}
}
@ -29788,13 +29788,19 @@ fn storePtrVal(
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{mut_kit.ty.fmt(mod)}),
};
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.ReinterpretDeclRef => unreachable,
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
};
if (reinterpret.write_packed) {
operand_val.writeToPackedMemory(operand_ty, mod, buffer[reinterpret.byte_offset..], 0) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.ReinterpretDeclRef => unreachable,
};
} else {
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.ReinterpretDeclRef => unreachable,
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
};
}
const val = Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.IllDefinedMemoryLayout => unreachable,
@ -29826,6 +29832,8 @@ const ComptimePtrMutationKit = struct {
reinterpret: struct {
val_ptr: *Value,
byte_offset: usize,
/// If set, write the operand to packed memory
write_packed: bool = false,
},
/// If the root decl could not be used as parent, this means `ty` is the type that
/// caused that by not having a well-defined layout.
@ -30189,21 +30197,43 @@ fn beginComptimePtrMutation(
);
},
.@"union" => {
// We need to set the active field of the union.
const union_tag_ty = base_child_ty.unionTagTypeHypothetical(mod);
const payload = &val_ptr.castTag(.@"union").?.data;
payload.tag = try mod.enumValueFieldIndex(union_tag_ty, field_index);
const layout = base_child_ty.containerLayout(mod);
return beginComptimePtrMutationInner(
sema,
block,
src,
parent.ty.structFieldType(field_index, mod),
&payload.val,
ptr_elem_ty,
parent.mut_decl,
);
const tag_type = base_child_ty.unionTagTypeHypothetical(mod);
const hypothetical_tag = try mod.enumValueFieldIndex(tag_type, field_index);
if (layout == .Auto or (payload.tag != null and hypothetical_tag.eql(payload.tag.?, tag_type, mod))) {
// We need to set the active field of the union.
payload.tag = hypothetical_tag;
const field_ty = parent.ty.structFieldType(field_index, mod);
return beginComptimePtrMutationInner(
sema,
block,
src,
field_ty,
&payload.val,
ptr_elem_ty,
parent.mut_decl,
);
} else {
// Writing to a different field (a different or unknown tag is active) requires reinterpreting
// memory of the entire union, which requires knowing its abiSize.
try sema.resolveTypeLayout(parent.ty);
// This union value no longer has a well-defined tag type.
// The reinterpretation will read it back out as .none.
payload.val = try payload.val.unintern(sema.arena, mod);
return ComptimePtrMutationKit{
.mut_decl = parent.mut_decl,
.pointee = .{ .reinterpret = .{
.val_ptr = val_ptr,
.byte_offset = 0,
.write_packed = layout == .Packed,
} },
.ty = parent.ty,
};
}
},
.slice => switch (field_index) {
Value.slice_ptr_index => return beginComptimePtrMutationInner(
@ -30704,6 +30734,7 @@ fn bitCastVal(
// For types with well-defined memory layouts, we serialize them to a byte buffer,
// then deserialize to the new type.
const abi_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
const buffer = try sema.gpa.alloc(u8, abi_size);
defer sema.gpa.free(buffer);
val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
@ -30720,6 +30751,63 @@ fn bitCastVal(
};
}
/// Reinterprets `val`, currently typed as `old_ty`, as a value of `field_ty` by
/// serializing it into a scratch byte buffer and deserializing it again. This is
/// used when a field of an extern or packed union is read while a different field
/// (or no known field) is active.
///
/// * Returns `val` unchanged when `old_ty` and `field_ty` are equal.
/// * Returns `null` when serialization reports `error.ReinterpretDeclRef`.
/// * Emits a compile error through `sema.fail` when serialization or
///   deserialization is unimplemented for the involved type.
///
/// `layout` selects byte-wise (`.Extern`) or bit-packed (`.Packed`) serialization;
/// `.Auto` is unreachable here.
fn bitCastUnionFieldVal(
    sema: *Sema,
    block: *Block,
    src: LazySrcLoc,
    val: Value,
    old_ty: Type,
    field_ty: Type,
    layout: std.builtin.Type.ContainerLayout,
) !?Value {
    const mod = sema.mod;
    if (old_ty.eql(field_ty, mod)) return val;

    const old_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
    const field_size = try sema.usizeCast(block, src, field_ty.abiSize(mod));
    const endian = mod.getTarget().cpu.arch.endian();

    // The buffer must hold whichever of the two representations is larger.
    const buffer = try sema.gpa.alloc(u8, @max(old_size, field_size));
    defer sema.gpa.free(buffer);

    // Reading a larger value means we need to reinterpret from undefined bytes.
    // Freshly allocated memory has unspecified contents, so the whole buffer is
    // pre-filled with the 0xaa "undefined" pattern before the old value is written.
    // This covers the bytes past `old_size`, and also any bits inside the shared
    // bytes that the write below leaves untouched (e.g. a packed field that has more
    // bits than `old_ty` but the same ABI size), keeping the result deterministic.
    // NOTE(review): assumes `writeToMemory`/`writeToPackedMemory` only modify the
    // bytes/bits that belong to `old_ty` - confirm against their implementations.
    @memset(buffer, 0xaa);

    const offset = switch (layout) {
        .Extern => offset: {
            val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
                error.OutOfMemory => return error.OutOfMemory,
                error.ReinterpretDeclRef => return null,
                error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
                error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{old_ty.fmt(mod)}),
            };
            // Extern union fields are read from the start of the buffer.
            break :offset 0;
        },
        .Packed => offset: {
            val.writeToPackedMemory(old_ty, mod, buffer, 0) catch |err| switch (err) {
                error.OutOfMemory => return error.OutOfMemory,
                error.ReinterpretDeclRef => return null,
            };
            // On big-endian targets the field is read from the tail of the buffer,
            // on little-endian targets from its start.
            break :offset if (endian == .Big) buffer.len - field_size else 0;
        },
        .Auto => unreachable,
    };

    return Value.readFromMemory(field_ty, mod, buffer[offset..], sema.arena) catch |err| switch (err) {
        error.OutOfMemory => return error.OutOfMemory,
        error.IllDefinedMemoryLayout => unreachable,
        error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{field_ty.fmt(mod)}),
    };
}
fn coerceArrayPtrToSlice(
sema: *Sema,
block: *Block,

View File

@ -84,22 +84,27 @@ pub fn print(
if (level == 0) {
return writer.writeAll(".{ ... }");
}
const union_val = val.castTag(.@"union").?.data;
const payload = val.castTag(.@"union").?.data;
try writer.writeAll(".{ ");
if (union_val.tag.toIntern() != .none) {
if (payload.tag) |tag| {
try print(.{
.ty = ip.indexToKey(ty.toIntern()).union_type.enum_tag_ty.toType(),
.val = union_val.tag,
.val = tag,
}, writer, level - 1, mod);
try writer.writeAll(" = ");
const field_ty = ty.unionFieldType(union_val.tag, mod).?;
const field_ty = ty.unionFieldType(tag, mod).?;
try print(.{
.ty = field_ty,
.val = union_val.val,
.val = payload.val,
}, writer, level - 1, mod);
} else {
return writer.writeAll("(unknown tag)");
try writer.writeAll("(unknown tag) = ");
const backing_ty = try ty.unionBackingType(mod);
try print(.{
.ty = backing_ty,
.val = payload.val,
}, writer, level - 1, mod);
}
return writer.writeAll(" }");
@ -421,7 +426,12 @@ pub fn print(
.val = un.val.toValue(),
}, writer, level - 1, mod);
} else {
try writer.writeAll("(unknown tag)");
try writer.writeAll("(unknown tag) = ");
const backing_ty = try ty.unionBackingType(mod);
try print(.{
.ty = backing_ty,
.val = un.val.toValue(),
}, writer, level - 1, mod);
}
} else try writer.writeAll("...");
return writer.writeAll(" }");

View File

@ -1954,6 +1954,16 @@ pub const Type = struct {
return true;
}
/// Returns the type that backs this union's storage during comptime operations:
/// a `u8` array spanning the union's ABI size for extern unions, or an unsigned
/// integer as wide as the union's bit size for packed unions.
/// Asserts the type is either an extern or packed union.
pub fn unionBackingType(ty: Type, mod: *Module) !Type {
    switch (ty.containerLayout(mod)) {
        .Auto => unreachable,
        .Extern => {
            const byte_count = ty.abiSize(mod);
            return mod.arrayType(.{ .len = byte_count, .child = .u8_type });
        },
        .Packed => {
            const bit_count = ty.bitSize(mod);
            return mod.intType(.unsigned, @intCast(bit_count));
        },
    }
}
pub fn unionGetLayout(ty: Type, mod: *Module) Module.UnionLayout {
const ip = &mod.intern_pool;
const union_type = ip.indexToKey(ty.toIntern()).union_type;

View File

@ -327,11 +327,19 @@ pub const Value = struct {
},
.@"union" => {
const pl = val.castTag(.@"union").?.data;
return mod.intern(.{ .un = .{
.ty = ty.toIntern(),
.tag = try pl.tag.intern(ty.unionTagTypeHypothetical(mod), mod),
.val = try pl.val.intern(ty.unionFieldType(pl.tag, mod).?, mod),
} });
if (pl.tag) |pl_tag| {
return mod.intern(.{ .un = .{
.ty = ty.toIntern(),
.tag = try pl_tag.intern(ty.unionTagTypeHypothetical(mod), mod),
.val = try pl.val.intern(ty.unionFieldType(pl_tag, mod).?, mod),
} });
} else {
return mod.intern(.{ .un = .{
.ty = ty.toIntern(),
.tag = .none,
.val = try pl.val.intern(try ty.unionBackingType(mod), mod),
} });
}
},
}
}
@ -399,10 +407,7 @@ pub const Value = struct {
.un => |un| Tag.@"union".create(arena, .{
// toValue asserts that the value cannot be .none which is valid on unions.
.tag = .{
.ip_index = un.tag,
.legacy = undefined,
},
.tag = if (un.tag == .none) null else un.tag.toValue(),
.val = un.val.toValue(),
}),
@ -709,21 +714,22 @@ pub const Value = struct {
.Union => switch (ty.containerLayout(mod)) {
.Auto => return error.IllDefinedMemoryLayout, // Sema is supposed to have emitted a compile error already
.Extern => {
const union_obj = mod.typeToUnion(ty).?;
if (val.unionTag(mod)) |union_tag| {
const union_obj = mod.typeToUnion(ty).?;
const field_index = mod.unionTagFieldIndex(union_obj, union_tag).?;
const field_type = union_obj.field_types.get(&mod.intern_pool)[field_index].toType();
const field_val = try val.fieldValue(mod, field_index);
const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
} else {
const union_size = ty.abiSize(mod);
const array_type = try mod.arrayType(.{ .len = union_size, .child = .u8_type });
return writeToMemory(val.unionValue(mod), array_type, mod, buffer[0..@as(usize, @intCast(union_size))]);
const backing_ty = try ty.unionBackingType(mod);
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
return writeToMemory(val.unionValue(mod), backing_ty, mod, buffer[0..byte_count]);
}
},
.Packed => {
const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
const backing_ty = try ty.unionBackingType(mod);
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
return writeToPackedMemory(val, ty, mod, buffer[0..byte_count], 0);
},
},
@ -842,9 +848,8 @@ pub const Value = struct {
const field_val = try val.fieldValue(mod, field_index);
return field_val.writeToPackedMemory(field_type, mod, buffer, bit_offset);
} else {
const union_bits: u16 = @intCast(ty.bitSize(mod));
const int_ty = try mod.intType(.unsigned, union_bits);
return val.unionValue(mod).writeToPackedMemory(int_ty, mod, buffer, bit_offset);
const backing_ty = try ty.unionBackingType(mod);
return val.unionValue(mod).writeToPackedMemory(backing_ty, mod, buffer, bit_offset);
}
},
}
@ -1146,10 +1151,8 @@ pub const Value = struct {
.Union => switch (ty.containerLayout(mod)) {
.Auto, .Extern => unreachable, // Handled by non-packed readFromMemory
.Packed => {
const union_bits: u16 = @intCast(ty.bitSize(mod));
assert(union_bits != 0);
const int_ty = try mod.intType(.unsigned, union_bits);
const val = (try readFromPackedMemory(int_ty, mod, buffer, bit_offset, arena)).toIntern();
const backing_ty = try ty.unionBackingType(mod);
const val = (try readFromPackedMemory(backing_ty, mod, buffer, bit_offset, arena)).toIntern();
return (try mod.intern(.{ .un = .{
.ty = ty.toIntern(),
.tag = .none,
@ -4017,7 +4020,7 @@ pub const Value = struct {
data: Data,
pub const Data = struct {
tag: Value,
tag: ?Value,
val: Value,
};
};

View File

@ -455,54 +455,3 @@ test "type pun null pointer-like optional" {
// note that expectEqual hides the bug
try testing.expect(@as(*const ?*i8, @ptrCast(&p)).* == null);
}
test "reinterpret extern union" {
    const Overlay = extern union {
        a: u32,
        b: u8 align(8),
    };

    // Fill every byte of the union with 42 (0x2a) at comptime.
    comptime var overlay: Overlay = undefined;
    comptime @memset(std.mem.asBytes(&overlay), 42);

    // Both fields must observe the fill pattern when checked at comptime...
    try comptime testing.expect(0x2a2a2a2a == overlay.a);
    try comptime testing.expect(42 == overlay.b);

    // ...and when the comparison itself is performed by the test runner.
    try testing.expectEqual(@as(u32, 0x2a2a2a2a), overlay.a);
    try testing.expectEqual(42, overlay.b);
}
test "reinterpret packed union" {
    {
        // A packed union whose bytes are all set to 42 (0x2a).
        const Overlay = packed union {
            a: u32,
            b: u8 align(8),
        };

        comptime var overlay: Overlay = undefined;
        comptime @memset(std.mem.asBytes(&overlay), 42);

        try comptime testing.expect(0x2a2a2a2a == overlay.a);
        try comptime testing.expect(0x2a == overlay.b);

        try testing.expectEqual(@as(u32, 0x2a2a2a2a), overlay.a);
        try testing.expectEqual(0x2a, overlay.b);
    }
    {
        // Two packed unions nested in a packed struct; every byte is set to 0xaa
        // and each union is read through both of its fields.
        const Bits = packed union {
            a: u7,
            b: u1,
        };
        const Pair = packed struct {
            lsb: Bits,
            msb: Bits,
        };

        comptime var pair: Pair = undefined;
        comptime @memset(std.mem.asBytes(&pair), 0xaa);

        try comptime testing.expectEqual(@as(u7, 0x2a), pair.lsb.a);
        try comptime testing.expectEqual(@as(u1, 0), pair.lsb.b);
        try comptime testing.expectEqual(@as(u7, 0x55), pair.msb.a);
        try comptime testing.expectEqual(@as(u1, 1), pair.msb.b);
    }
}

View File

@ -1,5 +1,6 @@
const builtin = @import("builtin");
const std = @import("std");
const endian = builtin.cpu.arch.endian();
const expect = std.testing.expect;
const assert = std.debug.assert;
const expectEqual = std.testing.expectEqual;
@ -1660,15 +1661,220 @@ test "union with 128 bit integer" {
}
}
test "memset extern union at comptime" {
test "memset extern union" {
const U = extern union {
foo: u8,
bar: u32,
};
const u = comptime blk: {
var u: U = undefined;
@memset(std.mem.asBytes(&u), 0);
u.foo = 0;
break :blk u;
const S = struct {
fn doTheTest() !void {
var u: U = undefined;
@memset(std.mem.asBytes(&u), 0);
try expectEqual(@as(u8, 0), u.foo);
try expectEqual(@as(u32, 0), u.bar);
}
};
try expect(u.foo == 0);
try comptime S.doTheTest();
try S.doTheTest();
}
test "memset packed union" {
    const Punned = packed union {
        a: u32,
        b: u8,
    };

    const Runner = struct {
        /// Sets every byte of the union to 42 (0x2a) and checks both field views.
        fn doTheTest() !void {
            var punned: Punned = undefined;
            @memset(std.mem.asBytes(&punned), 42);

            try expectEqual(@as(u32, 0x2a2a2a2a), punned.a);
            try expectEqual(@as(u8, 0x2a), punned.b);
        }
    };

    // The comptime run always executes; the runtime run is skipped on wasm.
    try comptime Runner.doTheTest();

    if (builtin.cpu.arch.isWasm()) return error.SkipZigTest; // TODO

    try Runner.doTheTest();
}
/// Converts `v`, given as the integer a little-endian target would observe, into
/// the integer the native target observes for the same bytes in memory: the value
/// is returned unchanged on little-endian targets and byte-swapped on big-endian
/// ones. Delegates to `std.mem.littleToNative` instead of re-implementing it.
fn littleToNativeEndian(comptime T: type, v: T) T {
    return std.mem.littleToNative(T, v);
}
test "reinterpret extern union" {
    const Overlay = extern union {
        foo: u8,
        baz: u32 align(8),
        bar: u32,
    };

    const Runner = struct {
        fn doTheTest() !void {
            {
                // Undefined initialization: zero the bytes, write the wide field,
                // then overwrite part of it through the narrow field.
                const result = blk: {
                    var scratch: Overlay = undefined;
                    @memset(std.mem.asBytes(&scratch), 0);
                    scratch.bar = 0xbbbbbbbb;
                    scratch.foo = 0x2a;
                    break :blk scratch;
                };

                try expectEqual(@as(u8, 0x2a), result.foo);
                try expectEqual(littleToNativeEndian(u32, 0xbbbbbb2a), result.bar);
                try expectEqual(littleToNativeEndian(u32, 0xbbbbbb2a), result.baz);
            }

            {
                // Union initialization through the narrow field.
                var overlay: Overlay = .{ .foo = 0x2a };

                {
                    // Only the byte shared with `foo` is checked in the wide
                    // fields; which byte of the u32 that is depends on endianness.
                    const expected, const mask = switch (endian) {
                        .Little => .{ 0x2a, 0xff },
                        .Big => .{ 0x2a000000, 0xff000000 },
                    };

                    try expectEqual(@as(u8, 0x2a), overlay.foo);
                    try expectEqual(@as(u32, expected), overlay.bar & mask);
                    try expectEqual(@as(u32, expected), overlay.baz & mask);
                }

                // Writing to a larger field
                overlay.baz = 0xbbbbbbbb;
                try expectEqual(@as(u8, 0xbb), overlay.foo);
                try expectEqual(@as(u32, 0xbbbbbbbb), overlay.bar);
                try expectEqual(@as(u32, 0xbbbbbbbb), overlay.baz);

                // Writing to the same field
                overlay.baz = 0xcccccccc;
                try expectEqual(@as(u8, 0xcc), overlay.foo);
                try expectEqual(@as(u32, 0xcccccccc), overlay.bar);
                try expectEqual(@as(u32, 0xcccccccc), overlay.baz);

                // Writing to a smaller field
                overlay.foo = 0xdd;
                try expectEqual(@as(u8, 0xdd), overlay.foo);
                try expectEqual(littleToNativeEndian(u32, 0xccccccdd), overlay.bar);
                try expectEqual(littleToNativeEndian(u32, 0xccccccdd), overlay.baz);
            }
        }
    };

    // The comptime run always executes; the runtime run is skipped on the LLVM backend.
    try comptime Runner.doTheTest();

    if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; // TODO

    try Runner.doTheTest();
}
test "reinterpret packed union" {
    const Punned = packed union {
        foo: u8,
        bar: u29,
        baz: u64,
        qux: u12,
    };

    const Runner = struct {
        fn doTheTest() !void {
            {
                // Undefined initialization: zero the bytes, write the widest field,
                // then overwrite its low bits through a narrower field.
                const result = blk: {
                    var scratch: Punned = undefined;
                    @memset(std.mem.asBytes(&scratch), 0);
                    scratch.baz = 0xbbbbbbbb;
                    scratch.qux = 0xe2a;
                    break :blk scratch;
                };

                try expectEqual(@as(u8, 0x2a), result.foo);
                try expectEqual(@as(u12, 0xe2a), result.qux);

                // https://github.com/ziglang/zig/issues/17360
                if (@inComptime()) {
                    try expectEqual(@as(u29, 0x1bbbbe2a), result.bar);
                    try expectEqual(@as(u64, 0xbbbbbe2a), result.baz);
                }
            }

            {
                // Union initialization through a narrow field; only the bits it
                // covers are checked in the wider fields.
                var punned: Punned = .{ .qux = 0xe2a };

                try expectEqual(@as(u8, 0x2a), punned.foo);
                try expectEqual(@as(u12, 0xe2a), punned.qux);
                try expectEqual(@as(u29, 0xe2a), punned.bar & 0xfff);
                try expectEqual(@as(u64, 0xe2a), punned.baz & 0xfff);

                // Writing to a larger field
                punned.baz = 0xbbbbbbbb;
                try expectEqual(@as(u8, 0xbb), punned.foo);
                try expectEqual(@as(u12, 0xbbb), punned.qux);
                try expectEqual(@as(u29, 0x1bbbbbbb), punned.bar);
                try expectEqual(@as(u64, 0xbbbbbbbb), punned.baz);

                // Writing to the same field
                punned.baz = 0xcccccccc;
                try expectEqual(@as(u8, 0xcc), punned.foo);
                try expectEqual(@as(u12, 0xccc), punned.qux);
                try expectEqual(@as(u29, 0x0ccccccc), punned.bar);
                try expectEqual(@as(u64, 0xcccccccc), punned.baz);

                // Writing to a smaller field
                punned.foo = 0xdd;
                try expectEqual(@as(u8, 0xdd), punned.foo);
                try expectEqual(@as(u12, 0xcdd), punned.qux);
                try expectEqual(@as(u29, 0x0cccccdd), punned.bar);
                try expectEqual(@as(u64, 0xccccccdd), punned.baz);
            }
        }
    };

    // The comptime run always executes; the runtime run is skipped on the
    // backends/targets listed below.
    try comptime Runner.doTheTest();

    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
    if (builtin.cpu.arch.isPPC()) return error.SkipZigTest; // TODO
    if (builtin.cpu.arch.isWasm()) return error.SkipZigTest; // TODO

    try Runner.doTheTest();
}
test "reinterpret packed union inside packed struct" {
    const Bits = packed union {
        a: u7,
        b: u1,
    };
    const Pair = packed struct {
        lo: Bits,
        hi: Bits,
    };

    const Runner = struct {
        fn doTheTest() !void {
            // Set every byte of the struct to 0x55 and read each nested union
            // through both of its fields.
            var pair: Pair = undefined;
            @memset(std.mem.asBytes(&pair), 0x55);

            try expectEqual(@as(u7, 0x55), pair.lo.a);
            try expectEqual(@as(u1, 1), pair.lo.b);
            try expectEqual(@as(u7, 0x2a), pair.hi.a);
            try expectEqual(@as(u1, 0), pair.hi.b);

            // Writing the 1-bit field must only change the lowest bit of `lo`.
            pair.lo.b = 0;
            try expectEqual(@as(u7, 0x54), pair.lo.a);
            try expectEqual(@as(u1, 0), pair.lo.b);

            // Likewise for `hi`.
            pair.hi.b = 1;
            try expectEqual(@as(u7, 0x2b), pair.hi.a);
            try expectEqual(@as(u1, 1), pair.hi.b);
        }
    };

    // The comptime run always executes; the runtime run is skipped on the C backend.
    try comptime Runner.doTheTest();

    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO

    try Runner.doTheTest();
}