From a5d4ad17b716508c2e1a2c1c0cf0b32bed08e26f Mon Sep 17 00:00:00 2001 From: Frank Denis <124872+jedisct1@users.noreply.github.com> Date: Wed, 20 Nov 2024 11:16:09 +0100 Subject: [PATCH] crypto.keccak.State: add checks to prevent insecure transitions (#22020) * crypto.keccak.State: don't unconditionally permute after a squeeze() Now, squeeze() behaves like absorb() Namely, squeeze(x[0..t]); squeeze(x[t..n)); with t <= n becomes equivalent to squeeze(x[0..n]). * keccak: in debug mode, track transitions to prevent insecure ones. Fixes #22019 --- lib/std/crypto/keccak_p.zig | 139 +++++++++++++++++++++++++++++++----- 1 file changed, 122 insertions(+), 17 deletions(-) diff --git a/lib/std/crypto/keccak_p.zig b/lib/std/crypto/keccak_p.zig index 0c522d5148..40b9ba43d3 100644 --- a/lib/std/crypto/keccak_p.zig +++ b/lib/std/crypto/keccak_p.zig @@ -4,6 +4,7 @@ const assert = std.debug.assert; const math = std.math; const mem = std.mem; const native_endian = builtin.cpu.arch.endian(); +const mode = @import("builtin").mode; /// The Keccak-f permutation. pub fn KeccakF(comptime f: u11) type { @@ -199,6 +200,46 @@ pub fn State(comptime f: u11, comptime capacity: u11, comptime rounds: u5) type comptime assert(f >= 200 and f <= 1600 and f % 200 == 0); // invalid state size comptime assert(capacity < f and capacity % 8 == 0); // invalid capacity size + // In debug mode, track transitions to prevent insecure ones. + const Op = enum { uninitialized, initialized, updated, absorb, squeeze }; + const TransitionTracker = if (mode == .Debug) struct { + op: Op = .uninitialized, + + fn to(tracker: *@This(), next_op: Op) void { + switch (next_op) { + .updated => { + switch (tracker.op) { + .uninitialized => @panic("cannot permute before initializing"), + else => {}, + } + }, + .absorb => { + switch (tracker.op) { + .squeeze => @panic("cannot absorb right after squeezing"), + else => {}, + } + }, + .squeeze => { + switch (tracker.op) { + .uninitialized => @panic("cannot squeeze before initializing"), + .initialized => @panic("cannot squeeze right after initializing"), + .absorb => @panic("cannot squeeze right after absorbing"), + else => {}, + } + }, + .uninitialized => @panic("cannot transition to uninitialized"), + .initialized => {}, + } + tracker.op = next_op; + } + } else struct { + // No-op in non-debug modes. + inline fn to(tracker: *@This(), next_op: Op) void { + _ = tracker; // no-op + _ = next_op; // no-op + } + }; + return struct { const Self = @This(); @@ -215,67 +256,108 @@ pub fn State(comptime f: u11, comptime capacity: u11, comptime rounds: u5) type st: KeccakF(f) = .{}, + transition: TransitionTracker = .{}, + /// Absorb a slice of bytes into the sponge. - pub fn absorb(self: *Self, bytes_: []const u8) void { - var bytes = bytes_; + pub fn absorb(self: *Self, bytes: []const u8) void { + self.transition.to(.absorb); + var i: usize = 0; if (self.offset > 0) { const left = @min(rate - self.offset, bytes.len); @memcpy(self.buf[self.offset..][0..left], bytes[0..left]); self.offset += left; + if (left == bytes.len) return; if (self.offset == rate) { - self.offset = 0; self.st.addBytes(self.buf[0..]); self.st.permuteR(rounds); + self.offset = 0; } - if (left == bytes.len) return; - bytes = bytes[left..]; + i = left; } - while (bytes.len >= rate) { - self.st.addBytes(bytes[0..rate]); + while (i + rate < bytes.len) : (i += rate) { + self.st.addBytes(bytes[i..][0..rate]); self.st.permuteR(rounds); - bytes = bytes[rate..]; } - if (bytes.len > 0) { - @memcpy(self.buf[0..bytes.len], bytes); - self.offset = bytes.len; + const left = bytes.len - i; + if (left > 0) { + @memcpy(self.buf[0..left], bytes[i..][0..left]); } + self.offset = left; } /// Initialize the state from a slice of bytes. - pub fn init(bytes: [f / 8]u8) Self { - return .{ .st = KeccakF(f).init(bytes) }; + pub fn init(bytes: [f / 8]u8, delim: u8) Self { + var st = Self{ .st = KeccakF(f).init(bytes), .delim = delim }; + st.transition.to(.initialized); + return st; } /// Permute the state pub fn permute(self: *Self) void { + if (mode == .Debug) { + if (self.transition.op == .absorb and self.offset > 0) { + @panic("cannot permute with pending input - call fillBlock() or pad() instead"); + } + } + self.transition.to(.updated); self.st.permuteR(rounds); self.offset = 0; } - /// Align the input to the rate boundary. + /// Align the input to the rate boundary and permute. pub fn fillBlock(self: *Self) void { + self.transition.to(.absorb); self.st.addBytes(self.buf[0..self.offset]); self.st.permuteR(rounds); self.offset = 0; + self.transition.to(.updated); } /// Mark the end of the input. pub fn pad(self: *Self) void { + self.transition.to(.absorb); self.st.addBytes(self.buf[0..self.offset]); + if (self.offset == rate) { + self.st.permuteR(rounds); + self.offset = 0; + } self.st.addByte(self.delim, self.offset); self.st.addByte(0x80, rate - 1); self.st.permuteR(rounds); self.offset = 0; + self.transition.to(.updated); } /// Squeeze a slice of bytes from the sponge. + /// The function can be called multiple times. pub fn squeeze(self: *Self, out: []u8) void { + self.transition.to(.squeeze); var i: usize = 0; - while (i < out.len) : (i += rate) { - const left = @min(rate, out.len - i); - self.st.extractBytes(out[i..][0..left]); + if (self.offset == rate) { + self.st.permuteR(rounds); + } else if (self.offset > 0) { + @branchHint(.unlikely); + var buf: [rate]u8 = undefined; + self.st.extractBytes(buf[0..]); + const left = @min(rate - self.offset, out.len); + @memcpy(out[0..left], buf[self.offset..][0..left]); + self.offset += left; + if (left == out.len) return; + if (self.offset == rate) { + self.offset = 0; + self.st.permuteR(rounds); + } + i = left; + } + while (i + rate < out.len) : (i += rate) { + self.st.extractBytes(out[i..][0..rate]); self.st.permuteR(rounds); } + const left = out.len - i; + if (left > 0) { + self.st.extractBytes(out[i..][0..left]); + } + self.offset = left; } }; } @@ -298,3 +380,26 @@ test "Keccak-f800" { }; try std.testing.expectEqualSlices(u32, &st.st, &expected); } + +test "squeeze" { + var st = State(800, 256, 22).init([_]u8{0x80} ** 100, 0x01); + + var out0: [15]u8 = undefined; + var out1: [out0.len]u8 = undefined; + st.permute(); + var st0 = st; + st0.squeeze(out0[0..]); + var st1 = st; + st1.squeeze(out1[0 .. out1.len / 2]); + st1.squeeze(out1[out1.len / 2 ..]); + try std.testing.expectEqualSlices(u8, &out0, &out1); + + var out2: [100]u8 = undefined; + var out3: [out2.len]u8 = undefined; + var st2 = st; + st2.squeeze(out2[0..]); + var st3 = st; + st3.squeeze(out3[0 .. out2.len / 2]); + st3.squeeze(out3[out2.len / 2 ..]); + try std.testing.expectEqualSlices(u8, &out2, &out3); +}