mirror of
https://github.com/ziglang/zig.git
synced 2024-11-25 05:40:16 +00:00
add std.http.WebSocket
This commit is contained in:
parent
d36c182748
commit
b9fd0eeca6
@ -4,6 +4,7 @@ pub const protocol = @import("http/protocol.zig");
|
||||
pub const HeadParser = @import("http/HeadParser.zig");
|
||||
pub const ChunkParser = @import("http/ChunkParser.zig");
|
||||
pub const HeaderIterator = @import("http/HeaderIterator.zig");
|
||||
pub const WebSocket = @import("http/WebSocket.zig");
|
||||
|
||||
pub const Version = enum {
|
||||
@"HTTP/1.0",
|
||||
@ -318,6 +319,7 @@ test {
|
||||
_ = Status;
|
||||
_ = HeadParser;
|
||||
_ = ChunkParser;
|
||||
_ = WebSocket;
|
||||
_ = @import("http/test.zig");
|
||||
}
|
||||
}
|
||||
|
243
lib/std/http/WebSocket.zig
Normal file
243
lib/std/http/WebSocket.zig
Normal file
@ -0,0 +1,243 @@
|
||||
//! See https://tools.ietf.org/html/rfc6455
|
||||
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const WebSocket = @This();
|
||||
const assert = std.debug.assert;
|
||||
const native_endian = builtin.cpu.arch.endian();
|
||||
|
||||
key: []const u8,
|
||||
request: *std.http.Server.Request,
|
||||
recv_fifo: std.fifo.LinearFifo(u8, .Slice),
|
||||
reader: std.io.AnyReader,
|
||||
response: std.http.Server.Response,
|
||||
/// Number of bytes that have been peeked but not discarded yet.
|
||||
outstanding_len: usize,
|
||||
|
||||
pub const InitError = error{WebSocketUpgradeMissingKey} ||
|
||||
std.http.Server.Request.ReaderError;
|
||||
|
||||
pub fn init(
|
||||
ws: *WebSocket,
|
||||
request: *std.http.Server.Request,
|
||||
send_buffer: []u8,
|
||||
recv_buffer: []align(4) u8,
|
||||
) InitError!bool {
|
||||
var sec_websocket_key: ?[]const u8 = null;
|
||||
var upgrade_websocket: bool = false;
|
||||
var it = request.iterateHeaders();
|
||||
while (it.next()) |header| {
|
||||
if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) {
|
||||
sec_websocket_key = header.value;
|
||||
} else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) {
|
||||
if (!std.mem.eql(u8, header.value, "websocket"))
|
||||
return false;
|
||||
upgrade_websocket = true;
|
||||
}
|
||||
}
|
||||
if (!upgrade_websocket)
|
||||
return false;
|
||||
|
||||
const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey;
|
||||
|
||||
var sha1 = std.crypto.hash.Sha1.init(.{});
|
||||
sha1.update(key);
|
||||
sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
|
||||
sha1.final(&digest);
|
||||
var base64_digest: [28]u8 = undefined;
|
||||
assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len);
|
||||
|
||||
request.head.content_length = std.math.maxInt(u64);
|
||||
|
||||
ws.* = .{
|
||||
.key = key,
|
||||
.recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer),
|
||||
.reader = try request.reader(),
|
||||
.response = request.respondStreaming(.{
|
||||
.send_buffer = send_buffer,
|
||||
.respond_options = .{
|
||||
.status = .switching_protocols,
|
||||
.extra_headers = &.{
|
||||
.{ .name = "upgrade", .value = "websocket" },
|
||||
.{ .name = "connection", .value = "upgrade" },
|
||||
.{ .name = "sec-websocket-accept", .value = &base64_digest },
|
||||
},
|
||||
.transfer_encoding = .none,
|
||||
},
|
||||
}),
|
||||
.request = request,
|
||||
.outstanding_len = 0,
|
||||
};
|
||||
return true;
|
||||
}
|
||||
|
||||
pub const Header0 = packed struct(u8) {
|
||||
opcode: Opcode,
|
||||
rsv3: u1 = 0,
|
||||
rsv2: u1 = 0,
|
||||
rsv1: u1 = 0,
|
||||
fin: bool,
|
||||
};
|
||||
|
||||
pub const Header1 = packed struct(u8) {
|
||||
payload_len: enum(u7) {
|
||||
len16 = 126,
|
||||
len64 = 127,
|
||||
_,
|
||||
},
|
||||
mask: bool,
|
||||
};
|
||||
|
||||
pub const Opcode = enum(u4) {
|
||||
continuation = 0,
|
||||
text = 1,
|
||||
binary = 2,
|
||||
connection_close = 8,
|
||||
ping = 9,
|
||||
/// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional
|
||||
/// heartbeat. A response to an unsolicited Pong frame is not expected."
|
||||
pong = 10,
|
||||
_,
|
||||
};
|
||||
|
||||
pub const ReadSmallTextMessageError = error{
|
||||
ConnectionClose,
|
||||
UnexpectedOpCode,
|
||||
MessageTooBig,
|
||||
MissingMaskBit,
|
||||
} || RecvError;
|
||||
|
||||
pub const SmallMessage = struct {
|
||||
/// Can be text, binary, or ping.
|
||||
opcode: Opcode,
|
||||
data: []u8,
|
||||
};
|
||||
|
||||
/// Reads the next message from the WebSocket stream, failing if the message does not fit
|
||||
/// into `recv_buffer`.
|
||||
pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
|
||||
while (true) {
|
||||
const header_bytes = (try recv(ws, 2))[0..2];
|
||||
const h0: Header0 = @bitCast(header_bytes[0]);
|
||||
const h1: Header1 = @bitCast(header_bytes[1]);
|
||||
|
||||
switch (h0.opcode) {
|
||||
.text, .binary, .pong, .ping => {},
|
||||
.connection_close => return error.ConnectionClose,
|
||||
.continuation => return error.UnexpectedOpCode,
|
||||
_ => return error.UnexpectedOpCode,
|
||||
}
|
||||
|
||||
if (!h0.fin) return error.MessageTooBig;
|
||||
if (!h1.mask) return error.MissingMaskBit;
|
||||
|
||||
const len: usize = switch (h1.payload_len) {
|
||||
.len16 => try recvReadInt(ws, u16),
|
||||
.len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig,
|
||||
else => @intFromEnum(h1.payload_len),
|
||||
};
|
||||
if (len > ws.recv_fifo.buf.len) return error.MessageTooBig;
|
||||
|
||||
const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*);
|
||||
const payload = try recv(ws, len);
|
||||
|
||||
// Skip pongs.
|
||||
if (h0.opcode == .pong) continue;
|
||||
|
||||
// The last item may contain a partial word of unused data.
|
||||
const floored_len = (payload.len / 4) * 4;
|
||||
const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len]));
|
||||
for (u32_payload) |*elem| elem.* ^= mask;
|
||||
const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len];
|
||||
for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m;
|
||||
|
||||
return .{
|
||||
.opcode = h0.opcode,
|
||||
.data = payload,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const RecvError = std.http.Server.Request.ReadError || error{EndOfStream};
|
||||
|
||||
fn recv(ws: *WebSocket, len: usize) RecvError![]u8 {
|
||||
ws.recv_fifo.discard(ws.outstanding_len);
|
||||
assert(len <= ws.recv_fifo.buf.len);
|
||||
if (len > ws.recv_fifo.count) {
|
||||
const small_buf = ws.recv_fifo.writableSlice(0);
|
||||
const needed = len - ws.recv_fifo.count;
|
||||
const buf = if (small_buf.len >= needed) small_buf else b: {
|
||||
ws.recv_fifo.realign();
|
||||
break :b ws.recv_fifo.writableSlice(0);
|
||||
};
|
||||
const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed)));
|
||||
if (n < needed) return error.EndOfStream;
|
||||
ws.recv_fifo.update(n);
|
||||
}
|
||||
ws.outstanding_len = len;
|
||||
// TODO: improve the std lib API so this cast isn't necessary.
|
||||
return @constCast(ws.recv_fifo.readableSliceOfLen(len));
|
||||
}
|
||||
|
||||
fn recvReadInt(ws: *WebSocket, comptime I: type) !I {
|
||||
const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*);
|
||||
return switch (native_endian) {
|
||||
.little => @byteSwap(unswapped),
|
||||
.big => unswapped,
|
||||
};
|
||||
}
|
||||
|
||||
pub const WriteError = std.http.Server.Response.WriteError;
|
||||
|
||||
pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void {
|
||||
const iovecs: [1]std.posix.iovec_const = .{
|
||||
.{ .base = message.ptr, .len = message.len },
|
||||
};
|
||||
return writeMessagev(ws, &iovecs, opcode);
|
||||
}
|
||||
|
||||
pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void {
|
||||
const total_len = l: {
|
||||
var total_len: u64 = 0;
|
||||
for (message) |iovec| total_len += iovec.len;
|
||||
break :l total_len;
|
||||
};
|
||||
|
||||
var header_buf: [2 + 8]u8 = undefined;
|
||||
header_buf[0] = @bitCast(@as(Header0, .{
|
||||
.opcode = opcode,
|
||||
.fin = true,
|
||||
}));
|
||||
const header = switch (total_len) {
|
||||
0...125 => blk: {
|
||||
header_buf[1] = @bitCast(@as(Header1, .{
|
||||
.payload_len = @enumFromInt(total_len),
|
||||
.mask = false,
|
||||
}));
|
||||
break :blk header_buf[0..2];
|
||||
},
|
||||
126...0xffff => blk: {
|
||||
header_buf[1] = @bitCast(@as(Header1, .{
|
||||
.payload_len = .len16,
|
||||
.mask = false,
|
||||
}));
|
||||
std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big);
|
||||
break :blk header_buf[0..4];
|
||||
},
|
||||
else => blk: {
|
||||
header_buf[1] = @bitCast(@as(Header1, .{
|
||||
.payload_len = .len64,
|
||||
.mask = false,
|
||||
}));
|
||||
std.mem.writeInt(u64, header_buf[2..10], total_len, .big);
|
||||
break :blk header_buf[0..10];
|
||||
},
|
||||
};
|
||||
|
||||
const response = &ws.response;
|
||||
try response.writeAll(header);
|
||||
for (message) |iovec|
|
||||
try response.writeAll(iovec.base[0..iovec.len]);
|
||||
try response.flush();
|
||||
}
|
Loading…
Reference in New Issue
Block a user