stdlib: base64 stream decoder

This commit is contained in:
Arwalk 2024-09-08 11:28:10 +02:00
parent a5d4ad17b7
commit fdba2a939c

View File

@ -301,6 +301,33 @@ pub const Base64Decoder = struct {
if (padding_chars != padding_len) return error.InvalidPadding;
}
}
// destWriter must be compatible with std.io.Writer's writeAll interface
pub fn decodeWriter(decoder: *const Base64Decoder, destWriter: anytype, source: []const u8) !void {
var temp = [_]u8{0} ** 4;
var chunker = window(u8, source, 4, 4);
while (chunker.next()) |chunk| {
const size = try decoder.calcSizeForSlice(chunk);
try decoder.decode(&temp, chunk);
try destWriter.writeAll(temp[0..size]);
}
}
// destWriter must be compatible with std.io.Writer's writeAll interface
// sourceReader must be compatible with std.io.Reader's read interface
pub fn decodeFromReaderToWriter(decoder: *const Base64Decoder, destWriter: anytype, sourceReader: anytype) !void {
var temp = [_]u8{0} ** 3;
var tempSource = [_]u8{0} ** 4;
while (true) {
const bytesRead = try sourceReader.read(&tempSource);
if (bytesRead == 0) {
break;
}
const size = try decoder.calcSizeForSlice(tempSource[0..bytesRead]);
try decoder.decode(&temp, tempSource[0..bytesRead]);
try destWriter.writeAll(temp[0..size]);
}
}
};
pub const Base64DecoderWithIgnore = struct {
@ -332,54 +359,82 @@ pub const Base64DecoderWithIgnore = struct {
return result;
}
fn WindowWithIgnore(comptime ReaderType: type) type {
return struct {
const Self = @This();
const Err = ReaderType.NoEofError;
reader: ReaderType,
decoder: *const Base64DecoderWithIgnore,
pub fn init(reader: ReaderType, decoder: *const Base64DecoderWithIgnore) Self {
return .{ .reader = reader, .decoder = decoder };
}
pub fn next(self: *Self, buffer: []u8) Err![]u8 {
var size: usize = 0;
while (true) {
const byte = self.reader.readByte() catch |err| switch (err) {
Self.Err.EndOfStream => {
break;
},
else => return err,
};
if (self.decoder.char_is_ignored[byte]) {
continue;
}
buffer[size] = byte;
size += 1;
if (size == 4) {
break;
}
}
if (size == 0) {
return Self.Err.EndOfStream;
}
return buffer[0..size];
}
};
}
/// Invalid characters that are not ignored result in error.InvalidCharacter.
/// Invalid padding results in error.InvalidPadding.
/// Decoding more data than can fit in dest results in error.NoSpaceLeft. See also ::calcSizeUpperBound.
/// Returns the number of bytes written to dest.
pub fn decode(decoder_with_ignore: *const Base64DecoderWithIgnore, dest: []u8, source: []const u8) Error!usize {
const decoder = &decoder_with_ignore.decoder;
var acc: u12 = 0;
var acc_len: u4 = 0;
var dest_idx: usize = 0;
var leftover_idx: ?usize = null;
for (source, 0..) |c, src_idx| {
if (decoder_with_ignore.char_is_ignored[c]) continue;
const d = decoder.char_to_index[c];
if (d == Base64Decoder.invalid_char) {
if (decoder.pad_char == null or c != decoder.pad_char.?) return error.InvalidCharacter;
leftover_idx = src_idx;
break;
}
acc = (acc << 6) + d;
acc_len += 6;
if (acc_len >= 8) {
if (dest_idx == dest.len) return error.NoSpaceLeft;
acc_len -= 8;
dest[dest_idx] = @as(u8, @truncate(acc >> acc_len));
dest_idx += 1;
}
var sourceStream = std.io.fixedBufferStream(source);
const sourceReader = sourceStream.reader();
var destStream = std.io.fixedBufferStream(dest);
const DestStreamType = @TypeOf(destStream);
const destWriter = destStream.writer();
decoder_with_ignore.decodeFromReaderToWriter(destWriter, sourceReader) catch |err| switch (err) {
DestStreamType.WriteError.NoSpaceLeft => return error.NoSpaceLeft,
WindowWithIgnore(@TypeOf(sourceReader)).Err.EndOfStream => unreachable,
error.InvalidCharacter, error.InvalidPadding => |e| return e,
};
return destStream.pos;
}
// destWriter must be compatible with std.io.Writer's writeAll interface
pub fn decodeWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, destWriter: anytype, source: []const u8) !void {
var stream = std.io.fixedBufferStream(source);
const reader = stream.reader();
return decoder_with_ignore.decodeFromReaderToWriter(destWriter, reader);
}
// destWriter must be compatible with std.io.Writer's writeAll interface
// sourceReader must be compatible with std.io.Reader's readByte interface
pub fn decodeFromReaderToWriter(decoder_with_ignore: *const Base64DecoderWithIgnore, destWriter: anytype, sourceReader: anytype) !void {
var buffer = [_]u8{0} ** 4;
const WindowType = WindowWithIgnore(@TypeOf(sourceReader));
var chunker = WindowType.init(sourceReader, decoder_with_ignore);
while (chunker.next(&buffer)) |chunk| {
try decoder_with_ignore.decoder.decodeWriter(destWriter, chunk);
} else |err| switch (err) {
WindowType.Err.EndOfStream => return,
else => return err,
}
if (acc_len > 4 or (acc & (@as(u12, 1) << acc_len) - 1) != 0) {
return error.InvalidPadding;
}
const padding_len = acc_len / 2;
if (leftover_idx == null) {
if (decoder.pad_char != null and padding_len != 0) return error.InvalidPadding;
return dest_idx;
}
const leftover = source[leftover_idx.?..];
if (decoder.pad_char) |pad_char| {
var padding_chars: usize = 0;
for (leftover) |c| {
if (decoder_with_ignore.char_is_ignored[c]) continue;
if (c != pad_char) {
return if (c == Base64Decoder.invalid_char) error.InvalidCharacter else error.InvalidPadding;
}
padding_chars += 1;
}
if (padding_chars != padding_len) return error.InvalidPadding;
}
return dest_idx;
}
};
@ -523,20 +578,55 @@ fn testAllApis(codecs: Codecs, expected_decoded: []const u8, expected_encoded: [
// Base64Decoder
{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)];
try codecs.Decoder.decode(decoded, expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded);
{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try codecs.Decoder.calcSizeForSlice(expected_encoded)];
try codecs.Decoder.decode(decoded, expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded);
}
//stream version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
try codecs.Decoder.decodeWriter(list.writer(), expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
// from reader to writer version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
var stream = std.io.fixedBufferStream(expected_encoded);
try codecs.Decoder.decodeFromReaderToWriter(list.writer(), stream.reader());
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
}
// Base64DecoderWithIgnore
{
const decoder_ignore_nothing = codecs.decoderWithIgnore("");
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)];
const written = try decoder_ignore_nothing.decode(decoded, expected_encoded);
try testing.expect(written <= decoded.len);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);
{
var buffer: [0x100]u8 = undefined;
const decoded = buffer[0..try decoder_ignore_nothing.calcSizeUpperBound(expected_encoded.len)];
const written = try decoder_ignore_nothing.decode(decoded, expected_encoded);
try testing.expect(written <= decoded.len);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);
}
//stream version
{
var list = try std.BoundedArray(u8, 0x100).init(0);
try decoder_ignore_nothing.decodeWriter(list.writer(), expected_encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
// from reader to writer
{
var list = try std.BoundedArray(u8, 0x100).init(0);
var stream = std.io.fixedBufferStream(expected_encoded);
try decoder_ignore_nothing.decodeFromReaderToWriter(list.writer(), stream.reader());
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
}
}
@ -546,6 +636,11 @@ fn testDecodeIgnoreSpace(codecs: Codecs, expected_decoded: []const u8, encoded:
const decoded = buffer[0..try decoder_ignore_space.calcSizeUpperBound(encoded.len)];
const written = try decoder_ignore_space.decode(decoded, encoded);
try testing.expectEqualSlices(u8, expected_decoded, decoded[0..written]);
//stream version
var list = try std.BoundedArray(u8, 0x100).init(0);
try decoder_ignore_space.decodeWriter(list.writer(), encoded);
try testing.expectEqualSlices(u8, expected_decoded, list.slice());
}
fn testError(codecs: Codecs, encoded: []const u8, expected_err: anyerror) !void {