llvm: add safety-check for Wasm memset

When lowering the `memset` instruction, LLVM will lower it to WebAssembly's
`memory.fill` instruction when the bulk-memory feature is enabled. This
instruction will trap when the destination address is out-of-bounds.
Under Zig's semantics, a slice of length 0 is allowed to carry an undefined (invalid) pointer.
To keep runtimes from trapping in that case, we add a safety check for slices so that the
memset instruction is only emitted when the length is greater than 0.
This commit is contained in:
Luuk de Gram 2023-07-04 20:12:48 +02:00
parent 836f9fceab
commit d54ebf4356
No known key found for this signature in database
GPG Key ID: A8CFE58E4DC7D664

View File

@ -8511,6 +8511,14 @@ pub const FuncGen = struct {
const dest_ptr = self.sliceOrArrayPtr(dest_slice, ptr_ty);
const is_volatile = ptr_ty.isVolatilePtr(mod);
// Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
// of the length. This means we need to emit a check where we skip the memset when the length
// is 0 as we allow for undefined pointers in 0-sized slices.
const needs_wasm_safety_check = safety and
o.target.isWasm() and
ptr_ty.isSlice(mod) and
std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
if (try self.air.value(bin_op.rhs, mod)) |elem_val| {
if (elem_val.isUndefDeep(mod)) {
// Even if safety is disabled, we still emit a memset to undefined since it conveys
@ -8521,7 +8529,11 @@ pub const FuncGen = struct {
else
u8_llvm_ty.getUndef();
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
if (needs_wasm_safety_check) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}
if (safety and mod.comp.bin_file.options.valgrind) {
self.valgrindMarkUndef(dest_ptr, len);
@ -8539,7 +8551,12 @@ pub const FuncGen = struct {
.val = byte_val,
});
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
if (needs_wasm_safety_check) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}
return null;
}
}
@ -8551,7 +8568,12 @@ pub const FuncGen = struct {
// In this case we can take advantage of LLVM's intrinsic.
const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
if (needs_wasm_safety_check) {
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
} else {
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
}
return null;
}
@ -8622,6 +8644,29 @@ pub const FuncGen = struct {
return null;
}
/// Emits a `memset` that is skipped at runtime when `len` is zero.
///
/// With the bulk-memory feature enabled, LLVM lowers `memset` to WebAssembly's
/// `memory.fill`, which traps on an out-of-bounds destination regardless of the
/// length. Zig permits an undefined pointer in a 0-length slice, so the memset
/// must not execute in that case.
///
/// On return the builder is positioned at the end of the continuation block,
/// so subsequent instructions are emitted after the guarded memset.
fn safeWasmMemset(
    self: *FuncGen,
    dest_ptr: *llvm.Value,
    fill_byte: *llvm.Value,
    len: *llvm.Value,
    dest_ptr_align: u32,
    is_volatile: bool,
) !void {
    const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
    const zero = llvm_usize_ty.constInt(0, .False);
    const len_is_nonzero = try self.cmp(len, zero, Type.usize, .neq);

    // Branch straight to the continuation block when there is nothing to fill;
    // no separate empty "skip" block is needed.
    const memset_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapSkip");
    const end_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapEnd");
    _ = self.builder.buildCondBr(len_is_nonzero, memset_block, end_block);

    self.builder.positionBuilderAtEnd(memset_block);
    _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
    _ = self.builder.buildBr(end_block);

    self.builder.positionBuilderAtEnd(end_block);
}
fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
const o = self.dg.object;
const mod = o.module;