From a567f3871ec06f3e6a8c0e6424aba556f1069ccc Mon Sep 17 00:00:00 2001
From: Robin Voetter
Date: Tue, 4 Jun 2024 22:09:15 +0200
Subject: [PATCH] spirv: improve shuffle codegen

---
 src/codegen/spirv.zig     | 63 +++++++++++++++++++++++++----
 test/behavior/shuffle.zig | 87 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 142 insertions(+), 8 deletions(-)

diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig
index 215a9421f1..09185211ef 100644
--- a/src/codegen/spirv.zig
+++ b/src/codegen/spirv.zig
@@ -4082,25 +4082,72 @@ const DeclGen = struct {
         const b = try self.resolve(extra.b);
         const mask = Value.fromInterned(extra.mask);
 
-        const ty = self.typeOfIndex(inst);
+        // Note: number of components in the result, a, and b may differ.
+        const result_ty = self.typeOfIndex(inst);
+        const a_ty = self.typeOf(extra.a);
+        const b_ty = self.typeOf(extra.b);
 
-        var wip = try self.elementWise(ty, true);
-        defer wip.deinit();
-        for (wip.results, 0..) |*result_id, i| {
+        const scalar_ty = result_ty.scalarType(mod);
+        const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
+
+        // If all of the types are SPIR-V vectors, we can use OpVectorShuffle.
+        if (self.isSpvVector(result_ty) and self.isSpvVector(a_ty) and self.isSpvVector(b_ty)) {
+            // The SPIR-V shuffle instruction is similar to the Air instruction, except that the elements are
+            // numbered consecutively instead of using negatives.
+
+            const components = try self.gpa.alloc(Word, result_ty.vectorLen(mod));
+            defer self.gpa.free(components);
+
+            const a_len = a_ty.vectorLen(mod);
+
+            for (components, 0..) |*component, i| {
+                const elem = try mask.elemValue(mod, i);
+                if (elem.isUndef(mod)) {
+                    // 0xFFFF_FFFF is explicitly valid for OpVectorShuffle: it marks the component as undefined.
+                    component.* = 0xFFFF_FFFF;
+                    continue;
+                }
+
+                const index = elem.toSignedInt(mod);
+                if (index >= 0) {
+                    component.* = @intCast(index);
+                } else {
+                    component.* = @intCast(~index + a_len);
+                }
+            }
+
+            const result_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpVectorShuffle, .{
+                .id_result_type = try self.resolveType(result_ty, .direct),
+                .id_result = result_id,
+                .vector_1 = a,
+                .vector_2 = b,
+                .components = components,
+            });
+            return result_id;
+        }
+
+        // Fall back to manually extracting and inserting components.
+
+        const components = try self.gpa.alloc(IdRef, result_ty.vectorLen(mod));
+        defer self.gpa.free(components);
+
+        for (components, 0..) |*id, i| {
             const elem = try mask.elemValue(mod, i);
             if (elem.isUndef(mod)) {
-                result_id.* = try self.spv.constUndef(wip.ty_id);
+                id.* = try self.spv.constUndef(scalar_ty_id);
                 continue;
             }
 
             const index = elem.toSignedInt(mod);
             if (index >= 0) {
-                result_id.* = try self.extractVectorComponent(wip.ty, a, @intCast(index));
+                id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index));
             } else {
-                result_id.* = try self.extractVectorComponent(wip.ty, b, @intCast(~index));
+                id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index));
             }
         }
-        return try wip.finalize();
+
+        return try self.constructVector(result_ty, components);
     }
 
     fn indicesToIds(self: *DeclGen, indices: []const u32) ![]IdRef {
diff --git a/test/behavior/shuffle.zig b/test/behavior/shuffle.zig
index fb16f3fbb3..2bcdbd1581 100644
--- a/test/behavior/shuffle.zig
+++ b/test/behavior/shuffle.zig
@@ -2,6 +2,7 @@ const std = @import("std");
 const builtin = @import("builtin");
 const mem = std.mem;
 const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
 
 test "@shuffle int" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
@@ -49,6 +50,92 @@ test "@shuffle int" {
     try comptime S.doTheTest();
 }
 
+test "@shuffle int strange sizes" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    try comptime testShuffle(2, 2, 2);
+    try testShuffle(2, 2, 2);
+    try comptime testShuffle(4, 4, 4);
+    try testShuffle(4, 4, 4);
+    try comptime testShuffle(7, 4, 4);
+    try testShuffle(7, 4, 4);
+    try comptime testShuffle(8, 6, 4);
+    try testShuffle(8, 6, 4);
+    try comptime testShuffle(2, 7, 5);
+    try testShuffle(2, 7, 5);
+    try comptime testShuffle(13, 16, 12);
+    try testShuffle(13, 16, 12);
+    try comptime testShuffle(19, 3, 17);
+    try testShuffle(19, 3, 17);
+    try comptime testShuffle(1, 10, 1);
+    try testShuffle(1, 10, 1);
+}
+
+/// Shuffles an `a_len`-element and a `b_len`-element `i32` vector into an
+/// `x_len`-element result using a mask derived from a fixed seed, then checks
+/// every result element against the source element the mask selects.
+fn testShuffle(
+    comptime x_len: comptime_int,
+    comptime a_len: comptime_int,
+    comptime b_len: comptime_int,
+) !void {
+    const T = i32;
+    const XT = @Vector(x_len, T);
+    const AT = @Vector(a_len, T);
+    const BT = @Vector(b_len, T);
+
+    const a_elems = comptime blk: {
+        var elems: [a_len]T = undefined;
+        for (&elems, 0..) |*elem, i| elem.* = @intCast(100 + i);
+        break :blk elems;
+    };
+    var a: AT = a_elems;
+    _ = &a;
+
+    const b_elems = comptime blk: {
+        var elems: [b_len]T = undefined;
+        for (&elems, 0..) |*elem, i| elem.* = @intCast(1000 + i);
+        break :blk elems;
+    };
+    var b: BT = b_elems;
+    _ = &b;
+
+    const mask_seed: []const i32 = &.{ -14, -31, 23, 1, 21, 13, 17, -21, -10, -27, -16, -5, 15, 14, -2, 26, 2, -31, -24, -16 };
+
+    const mask = comptime blk: {
+        var elems: [x_len]i32 = undefined;
+        for (&elems, 0..) |*elem, i| {
+            const mask_val = mask_seed[i];
+            if (mask_val >= 0) {
+                elem.* = @mod(mask_val, a_len);
+            } else {
+                // Keep the denominator positive (as @mod requires) and land in [-b_len, -1].
+                elem.* = ~@mod(~mask_val, b_len);
+            }
+        }
+
+        break :blk elems;
+    };
+
+    const x: XT = @shuffle(T, a, b, mask);
+
+    const x_elems: [x_len]T = x;
+    for (mask, x_elems) |m, x_elem| {
+        if (m >= 0) {
+            // Element from A
+            try expectEqual(a_elems[@intCast(m)], x_elem);
+        } else {
+            // Element from B
+            try expectEqual(b_elems[@intCast(~m)], x_elem);
+        }
+    }
+}
+
 test "@shuffle bool 1" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO