std.net, std.http: simplify

This commit is contained in:
Andrew Kelley 2024-02-18 20:22:09 -07:00
parent f1565e3d09
commit 6129ecd4fe
10 changed files with 714 additions and 841 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,132 @@
stream: std.net.Stream,
protocol: Protocol,
closing: bool,
read_buf: [buffer_size]u8,
read_start: u16,
read_end: u16,
pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
pub const Protocol = enum { plain };
pub fn rawReadAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
return switch (conn.protocol) {
.plain => conn.stream.readAtLeast(buffer, len),
// .tls => conn.tls_client.readAtLeast(conn.stream, buffer, len),
} catch |err| {
switch (err) {
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
}
};
}
pub fn fill(conn: *Connection) ReadError!void {
if (conn.read_end != conn.read_start) return;
const nread = try conn.rawReadAtLeast(conn.read_buf[0..], 1);
if (nread == 0) return error.EndOfStream;
conn.read_start = 0;
conn.read_end = @intCast(nread);
}
pub fn peek(conn: *Connection) []const u8 {
return conn.read_buf[conn.read_start..conn.read_end];
}
pub fn drop(conn: *Connection, num: u16) void {
conn.read_start += num;
}
pub fn readAtLeast(conn: *Connection, buffer: []u8, len: usize) ReadError!usize {
assert(len <= buffer.len);
var out_index: u16 = 0;
while (out_index < len) {
const available_read = conn.read_end - conn.read_start;
const available_buffer = buffer.len - out_index;
if (available_read > available_buffer) { // partially read buffered data
@memcpy(buffer[out_index..], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]);
out_index += @as(u16, @intCast(available_buffer));
conn.read_start += @as(u16, @intCast(available_buffer));
break;
} else if (available_read > 0) { // fully read buffered data
@memcpy(buffer[out_index..][0..available_read], conn.read_buf[conn.read_start..conn.read_end]);
out_index += available_read;
conn.read_start += available_read;
if (out_index >= len) break;
}
const leftover_buffer = available_buffer - available_read;
const leftover_len = len - out_index;
if (leftover_buffer > conn.read_buf.len) {
// skip the buffer if the output is large enough
return conn.rawReadAtLeast(buffer[out_index..], leftover_len);
}
try conn.fill();
}
return out_index;
}
pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
return conn.readAtLeast(buffer, 1);
}
pub const ReadError = error{
ConnectionTimedOut,
ConnectionResetByPeer,
UnexpectedReadFailure,
EndOfStream,
};
pub const Reader = std.io.Reader(*Connection, ReadError, read);
pub fn reader(conn: *Connection) Reader {
return .{ .context = conn };
}
pub fn writeAll(conn: *Connection, buffer: []const u8) WriteError!void {
return switch (conn.protocol) {
.plain => conn.stream.writeAll(buffer),
// .tls => return conn.tls_client.writeAll(conn.stream, buffer),
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
return switch (conn.protocol) {
.plain => conn.stream.write(buffer),
// .tls => return conn.tls_client.write(conn.stream, buffer),
} catch |err| switch (err) {
error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedWriteFailure,
};
}
pub const WriteError = error{
ConnectionResetByPeer,
UnexpectedWriteFailure,
};
pub const Writer = std.io.Writer(*Connection, WriteError, write);
pub fn writer(conn: *Connection) Writer {
return .{ .context = conn };
}
pub fn close(conn: *Connection) void {
conn.stream.close();
}
const Connection = @This();
const std = @import("../../std.zig");
const assert = std.debug.assert;

View File

@ -8,13 +8,12 @@ test "trailers" {
const gpa = testing.allocator; const gpa = testing.allocator;
var http_server = std.http.Server.init(.{ const address = try std.net.Address.parseIp("127.0.0.1", 0);
var http_server = try address.listen(.{
.reuse_address = true, .reuse_address = true,
}); });
const address = try std.net.Address.parseIp("127.0.0.1", 0);
try http_server.listen(address);
const port = http_server.socket.listen_address.in.getPort(); const port = http_server.listen_address.in.getPort();
const server_thread = try std.Thread.spawn(.{}, serverThread, .{&http_server}); const server_thread = try std.Thread.spawn(.{}, serverThread, .{&http_server});
defer server_thread.join(); defer server_thread.join();
@ -67,17 +66,14 @@ test "trailers" {
try testing.expect(client.connection_pool.free_len == 1); try testing.expect(client.connection_pool.free_len == 1);
} }
fn serverThread(http_server: *std.http.Server) anyerror!void { fn serverThread(http_server: *std.net.Server) anyerror!void {
const gpa = testing.allocator;
var header_buffer: [1024]u8 = undefined; var header_buffer: [1024]u8 = undefined;
var remaining: usize = 1; var remaining: usize = 1;
accept: while (remaining != 0) : (remaining -= 1) { accept: while (remaining != 0) : (remaining -= 1) {
var res = try http_server.accept(.{ const conn = try http_server.accept();
.allocator = gpa, defer conn.stream.close();
.client_header_buffer = &header_buffer,
}); var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
defer res.deinit();
res.wait() catch |err| switch (err) { res.wait() catch |err| switch (err) {
error.HttpHeadersInvalid => continue :accept, error.HttpHeadersInvalid => continue :accept,
@ -90,7 +86,7 @@ fn serverThread(http_server: *std.http.Server) anyerror!void {
} }
} }
fn serve(res: *std.http.Server.Response) !void { fn serve(res: *std.http.Server) !void {
try testing.expectEqualStrings(res.request.target, "/trailer"); try testing.expectEqualStrings(res.request.target, "/trailer");
res.transfer_encoding = .chunked; res.transfer_encoding = .chunked;
@ -99,3 +95,73 @@ fn serve(res: *std.http.Server.Response) !void {
try res.writeAll("World!\n"); try res.writeAll("World!\n");
try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n"); try res.connection.writeAll("0\r\nX-Checksum: aaaa\r\n\r\n");
} }
test "HTTP server handles a chunked transfer coding request" {
// This test requires spawning threads.
if (builtin.single_threaded) {
return error.SkipZigTest;
}
const native_endian = comptime builtin.cpu.arch.endian();
if (builtin.zig_backend == .stage2_llvm and native_endian == .big) {
// https://github.com/ziglang/zig/issues/13782
return error.SkipZigTest;
}
if (builtin.os.tag == .wasi) return error.SkipZigTest;
const allocator = std.testing.allocator;
const expect = std.testing.expect;
const max_header_size = 8192;
const address = try std.net.Address.parseIp("127.0.0.1", 0);
var server = try address.listen(.{ .reuse_address = true });
defer server.deinit();
const server_port = server.listen_address.in.getPort();
const server_thread = try std.Thread.spawn(.{}, (struct {
fn apply(s: *std.net.Server) !void {
var header_buffer: [max_header_size]u8 = undefined;
const conn = try s.accept();
defer conn.stream.close();
var res = std.http.Server.init(conn, .{ .client_header_buffer = &header_buffer });
try res.wait();
try expect(res.request.transfer_encoding == .chunked);
const server_body: []const u8 = "message from server!\n";
res.transfer_encoding = .{ .content_length = server_body.len };
res.extra_headers = &.{
.{ .name = "content-type", .value = "text/plain" },
};
res.keep_alive = false;
try res.send();
var buf: [128]u8 = undefined;
const n = try res.readAll(&buf);
try expect(std.mem.eql(u8, buf[0..n], "ABCD"));
_ = try res.writer().writeAll(server_body);
try res.finish();
}
}).apply, .{&server});
const request_bytes =
"POST / HTTP/1.1\r\n" ++
"Content-Type: text/plain\r\n" ++
"Transfer-Encoding: chunked\r\n" ++
"\r\n" ++
"1\r\n" ++
"A\r\n" ++
"1\r\n" ++
"B\r\n" ++
"2\r\n" ++
"CD\r\n" ++
"0\r\n" ++
"\r\n";
const stream = try std.net.tcpConnectToHost(allocator, "127.0.0.1", server_port);
defer stream.close();
_ = try stream.writeAll(request_bytes[0..]);
server_thread.join();
}

View File

@ -4,15 +4,17 @@ const assert = std.debug.assert;
const net = @This(); const net = @This();
const mem = std.mem; const mem = std.mem;
const os = std.os; const os = std.os;
const posix = std.posix;
const fs = std.fs; const fs = std.fs;
const io = std.io; const io = std.io;
const native_endian = builtin.target.cpu.arch.endian(); const native_endian = builtin.target.cpu.arch.endian();
// Windows 10 added support for unix sockets in build 17063, redstone 4 is the // Windows 10 added support for unix sockets in build 17063, redstone 4 is the
// first release to support them. // first release to support them.
pub const has_unix_sockets = @hasDecl(os.sockaddr, "un") and pub const has_unix_sockets = switch (builtin.os.tag) {
(builtin.target.os.tag != .windows or .windows => builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false,
builtin.os.version_range.windows.isAtLeast(.win10_rs4) orelse false); else => true,
};
pub const IPParseError = error{ pub const IPParseError = error{
Overflow, Overflow,
@ -206,6 +208,57 @@ pub const Address = extern union {
else => unreachable, else => unreachable,
} }
} }
pub const ListenError = posix.SocketError || posix.BindError || posix.ListenError ||
posix.SetSockOptError || posix.GetSockNameError;
pub const ListenOptions = struct {
/// How many connections the kernel will accept on the application's behalf.
/// If more than this many connections pool in the kernel, clients will start
/// seeing "Connection refused".
kernel_backlog: u31 = 128,
reuse_address: bool = false,
reuse_port: bool = false,
force_nonblocking: bool = false,
};
/// The returned `Server` has an open `stream`.
pub fn listen(address: Address, options: ListenOptions) ListenError!Server {
const nonblock: u32 = if (options.force_nonblocking) posix.SOCK.NONBLOCK else 0;
const sock_flags = posix.SOCK.STREAM | posix.SOCK.CLOEXEC | nonblock;
const proto: u32 = if (address.any.family == posix.AF.UNIX) 0 else posix.IPPROTO.TCP;
const sockfd = try posix.socket(address.any.family, sock_flags, proto);
var s: Server = .{
.listen_address = undefined,
.stream = .{ .handle = sockfd },
};
errdefer s.stream.close();
if (options.reuse_address) {
try posix.setsockopt(
sockfd,
posix.SOL.SOCKET,
posix.SO.REUSEADDR,
&mem.toBytes(@as(c_int, 1)),
);
}
if (options.reuse_port) {
try posix.setsockopt(
sockfd,
posix.SOL.SOCKET,
posix.SO.REUSEPORT,
&mem.toBytes(@as(c_int, 1)),
);
}
var socklen = address.getOsSockLen();
try posix.bind(sockfd, &address.any, socklen);
try posix.listen(sockfd, options.kernel_backlog);
try posix.getsockname(sockfd, &s.listen_address.any, &socklen);
return s;
}
}; };
pub const Ip4Address = extern struct { pub const Ip4Address = extern struct {
@ -657,7 +710,7 @@ pub fn connectUnixSocket(path: []const u8) !Stream {
os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block, os.SOCK.STREAM | os.SOCK.CLOEXEC | opt_non_block,
0, 0,
); );
errdefer os.closeSocket(sockfd); errdefer Stream.close(.{ .handle = sockfd });
var addr = try std.net.Address.initUnix(path); var addr = try std.net.Address.initUnix(path);
try os.connect(sockfd, &addr.any, addr.getOsSockLen()); try os.connect(sockfd, &addr.any, addr.getOsSockLen());
@ -669,7 +722,7 @@ fn if_nametoindex(name: []const u8) IPv6InterfaceError!u32 {
if (builtin.target.os.tag == .linux) { if (builtin.target.os.tag == .linux) {
var ifr: os.ifreq = undefined; var ifr: os.ifreq = undefined;
const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0); const sockfd = try os.socket(os.AF.UNIX, os.SOCK.DGRAM | os.SOCK.CLOEXEC, 0);
defer os.closeSocket(sockfd); defer Stream.close(.{ .handle = sockfd });
@memcpy(ifr.ifrn.name[0..name.len], name); @memcpy(ifr.ifrn.name[0..name.len], name);
ifr.ifrn.name[name.len] = 0; ifr.ifrn.name[name.len] = 0;
@ -738,7 +791,7 @@ pub fn tcpConnectToAddress(address: Address) TcpConnectToAddressError!Stream {
const sock_flags = os.SOCK.STREAM | nonblock | const sock_flags = os.SOCK.STREAM | nonblock |
(if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC); (if (builtin.target.os.tag == .windows) 0 else os.SOCK.CLOEXEC);
const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP); const sockfd = try os.socket(address.any.family, sock_flags, os.IPPROTO.TCP);
errdefer os.closeSocket(sockfd); errdefer Stream.close(.{ .handle = sockfd });
try os.connect(sockfd, &address.any, address.getOsSockLen()); try os.connect(sockfd, &address.any, address.getOsSockLen());
@ -1068,7 +1121,7 @@ fn linuxLookupName(
var prefixlen: i32 = 0; var prefixlen: i32 = 0;
const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC; const sock_flags = os.SOCK.DGRAM | os.SOCK.CLOEXEC;
if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: { if (os.socket(addr.addr.any.family, sock_flags, os.IPPROTO.UDP)) |fd| syscalls: {
defer os.closeSocket(fd); defer Stream.close(.{ .handle = fd });
os.connect(fd, da, dalen) catch break :syscalls; os.connect(fd, da, dalen) catch break :syscalls;
key |= DAS_USABLE; key |= DAS_USABLE;
os.getsockname(fd, sa, &salen) catch break :syscalls; os.getsockname(fd, sa, &salen) catch break :syscalls;
@ -1553,7 +1606,7 @@ fn resMSendRc(
}, },
else => |e| return e, else => |e| return e,
}; };
defer os.closeSocket(fd); defer Stream.close(.{ .handle = fd });
// Past this point, there are no errors. Each individual query will // Past this point, there are no errors. Each individual query will
// yield either no reply (indicated by zero length) or an answer // yield either no reply (indicated by zero length) or an answer
@ -1729,13 +1782,15 @@ fn dnsParseCallback(ctx: dpc_ctx, rr: u8, data: []const u8, packet: []const u8)
} }
pub const Stream = struct { pub const Stream = struct {
// Underlying socket descriptor. /// Underlying platform-defined type which may or may not be
// Note that on some platforms this may not be interchangeable with a /// interchangeable with a file system file descriptor.
// regular files descriptor. handle: posix.socket_t,
handle: os.socket_t,
pub fn close(self: Stream) void { pub fn close(s: Stream) void {
os.closeSocket(self.handle); switch (builtin.os.tag) {
.windows => std.os.windows.closesocket(s.handle) catch unreachable,
else => posix.close(s.handle),
}
} }
pub const ReadError = os.ReadError; pub const ReadError = os.ReadError;
@ -1839,156 +1894,38 @@ pub const Stream = struct {
} }
}; };
pub const StreamServer = struct { pub const Server = struct {
/// Copied from `Options` on `init`.
kernel_backlog: u31,
reuse_address: bool,
reuse_port: bool,
force_nonblocking: bool,
/// `undefined` until `listen` returns successfully.
listen_address: Address, listen_address: Address,
stream: std.net.Stream,
sockfd: ?os.socket_t,
pub const Options = struct {
/// How many connections the kernel will accept on the application's behalf.
/// If more than this many connections pool in the kernel, clients will start
/// seeing "Connection refused".
kernel_backlog: u31 = 128,
/// Enable SO.REUSEADDR on the socket.
reuse_address: bool = false,
/// Enable SO.REUSEPORT on the socket.
reuse_port: bool = false,
/// Force non-blocking mode.
force_nonblocking: bool = false,
};
/// After this call succeeds, resources have been acquired and must
/// be released with `deinit`.
pub fn init(options: Options) StreamServer {
return StreamServer{
.sockfd = null,
.kernel_backlog = options.kernel_backlog,
.reuse_address = options.reuse_address,
.reuse_port = options.reuse_port,
.force_nonblocking = options.force_nonblocking,
.listen_address = undefined,
};
}
/// Release all resources. The `StreamServer` memory becomes `undefined`.
pub fn deinit(self: *StreamServer) void {
self.close();
self.* = undefined;
}
pub fn listen(self: *StreamServer, address: Address) !void {
const nonblock = 0;
const sock_flags = os.SOCK.STREAM | os.SOCK.CLOEXEC | nonblock;
var use_sock_flags: u32 = sock_flags;
if (self.force_nonblocking) use_sock_flags |= os.SOCK.NONBLOCK;
const proto = if (address.any.family == os.AF.UNIX) @as(u32, 0) else os.IPPROTO.TCP;
const sockfd = try os.socket(address.any.family, use_sock_flags, proto);
self.sockfd = sockfd;
errdefer {
os.closeSocket(sockfd);
self.sockfd = null;
}
if (self.reuse_address) {
try os.setsockopt(
sockfd,
os.SOL.SOCKET,
os.SO.REUSEADDR,
&mem.toBytes(@as(c_int, 1)),
);
}
if (@hasDecl(os.SO, "REUSEPORT") and self.reuse_port) {
try os.setsockopt(
sockfd,
os.SOL.SOCKET,
os.SO.REUSEPORT,
&mem.toBytes(@as(c_int, 1)),
);
}
var socklen = address.getOsSockLen();
try os.bind(sockfd, &address.any, socklen);
try os.listen(sockfd, self.kernel_backlog);
try os.getsockname(sockfd, &self.listen_address.any, &socklen);
}
/// Stop listening. It is still necessary to call `deinit` after stopping listening.
/// Calling `deinit` will automatically call `close`. It is safe to call `close` when
/// not listening.
pub fn close(self: *StreamServer) void {
if (self.sockfd) |fd| {
os.closeSocket(fd);
self.sockfd = null;
self.listen_address = undefined;
}
}
pub const AcceptError = error{
ConnectionAborted,
/// The per-process limit on the number of open file descriptors has been reached.
ProcessFdQuotaExceeded,
/// The system-wide limit on the total number of open files has been reached.
SystemFdQuotaExceeded,
/// Not enough free memory. This often means that the memory allocation
/// is limited by the socket buffer limits, not by the system memory.
SystemResources,
/// Socket is not listening for new connections.
SocketNotListening,
ProtocolFailure,
/// Socket is in non-blocking mode and there is no connection to accept.
WouldBlock,
/// Firewall rules forbid connection.
BlockedByFirewall,
FileDescriptorNotASocket,
ConnectionResetByPeer,
NetworkSubsystemFailed,
OperationNotSupported,
} || os.UnexpectedError;
pub const Connection = struct { pub const Connection = struct {
stream: Stream, stream: std.net.Stream,
address: Address, address: Address,
}; };
/// If this function succeeds, the returned `Connection` is a caller-managed resource. pub fn deinit(s: *Server) void {
pub fn accept(self: *StreamServer) AcceptError!Connection { s.stream.close();
var accepted_addr: Address = undefined; s.* = undefined;
var adr_len: os.socklen_t = @sizeOf(Address); }
const accept_result = os.accept(self.sockfd.?, &accepted_addr.any, &adr_len, os.SOCK.CLOEXEC);
if (accept_result) |fd| { pub const AcceptError = posix.AcceptError;
return Connection{
.stream = Stream{ .handle = fd }, /// Blocks until a client connects to the server. The returned `Connection` has
.address = accepted_addr, /// an open stream.
}; pub fn accept(s: *Server) AcceptError!Connection {
} else |err| { var accepted_addr: Address = undefined;
return err; var addr_len: posix.socklen_t = @sizeOf(Address);
} const fd = try posix.accept(s.stream.handle, &accepted_addr.any, &addr_len, posix.SOCK.CLOEXEC);
return .{
.stream = .{ .handle = fd },
.address = accepted_addr,
};
} }
}; };
test { test {
_ = @import("net/test.zig"); _ = @import("net/test.zig");
_ = Server;
_ = Stream;
_ = Address;
} }

View File

@ -181,11 +181,9 @@ test "listen on a port, send bytes, receive bytes" {
// configured. // configured.
const localhost = try net.Address.parseIp("127.0.0.1", 0); const localhost = try net.Address.parseIp("127.0.0.1", 0);
var server = net.StreamServer.init(.{}); var server = try localhost.listen(.{});
defer server.deinit(); defer server.deinit();
try server.listen(localhost);
const S = struct { const S = struct {
fn clientFn(server_address: net.Address) !void { fn clientFn(server_address: net.Address) !void {
const socket = try net.tcpConnectToAddress(server_address); const socket = try net.tcpConnectToAddress(server_address);
@ -215,17 +213,11 @@ test "listen on an in use port" {
const localhost = try net.Address.parseIp("127.0.0.1", 0); const localhost = try net.Address.parseIp("127.0.0.1", 0);
var server1 = net.StreamServer.init(net.StreamServer.Options{ var server1 = try localhost.listen(.{ .reuse_port = true });
.reuse_port = true,
});
defer server1.deinit(); defer server1.deinit();
try server1.listen(localhost);
var server2 = net.StreamServer.init(net.StreamServer.Options{ var server2 = try server1.listen_address.listen(.{ .reuse_port = true });
.reuse_port = true,
});
defer server2.deinit(); defer server2.deinit();
try server2.listen(server1.listen_address);
} }
fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void { fn testClientToHost(allocator: mem.Allocator, name: []const u8, port: u16) anyerror!void {
@ -252,7 +244,7 @@ fn testClient(addr: net.Address) anyerror!void {
try testing.expect(mem.eql(u8, msg, "hello from server\n")); try testing.expect(mem.eql(u8, msg, "hello from server\n"));
} }
fn testServer(server: *net.StreamServer) anyerror!void { fn testServer(server: *net.Server) anyerror!void {
if (builtin.os.tag == .wasi) return error.SkipZigTest; if (builtin.os.tag == .wasi) return error.SkipZigTest;
var client = try server.accept(); var client = try server.accept();
@ -274,15 +266,14 @@ test "listen on a unix socket, send bytes, receive bytes" {
} }
} }
var server = net.StreamServer.init(.{});
defer server.deinit();
const socket_path = try generateFileName("socket.unix"); const socket_path = try generateFileName("socket.unix");
defer testing.allocator.free(socket_path); defer testing.allocator.free(socket_path);
const socket_addr = try net.Address.initUnix(socket_path); const socket_addr = try net.Address.initUnix(socket_path);
defer std.fs.cwd().deleteFile(socket_path) catch {}; defer std.fs.cwd().deleteFile(socket_path) catch {};
try server.listen(socket_addr);
var server = try socket_addr.listen(.{});
defer server.deinit();
const S = struct { const S = struct {
fn clientFn(path: []const u8) !void { fn clientFn(path: []const u8) !void {
@ -323,9 +314,8 @@ test "non-blocking tcp server" {
} }
const localhost = try net.Address.parseIp("127.0.0.1", 0); const localhost = try net.Address.parseIp("127.0.0.1", 0);
var server = net.StreamServer.init(.{ .force_nonblocking = true }); var server = localhost.listen(.{ .force_nonblocking = true });
defer server.deinit(); defer server.deinit();
try server.listen(localhost);
const accept_err = server.accept(); const accept_err = server.accept();
try testing.expectError(error.WouldBlock, accept_err); try testing.expectError(error.WouldBlock, accept_err);

View File

@ -3598,14 +3598,6 @@ pub fn shutdown(sock: socket_t, how: ShutdownHow) ShutdownError!void {
} }
} }
pub fn closeSocket(sock: socket_t) void {
if (builtin.os.tag == .windows) {
windows.closesocket(sock) catch unreachable;
} else {
close(sock);
}
}
pub const BindError = error{ pub const BindError = error{
/// The address is protected, and the user is not the superuser. /// The address is protected, and the user is not the superuser.
/// For UNIX domain sockets: Search permission is denied on a component /// For UNIX domain sockets: Search permission is denied on a component

View File

@ -4,6 +4,7 @@ const assert = std.debug.assert;
const mem = std.mem; const mem = std.mem;
const net = std.net; const net = std.net;
const os = std.os; const os = std.os;
const posix = std.posix;
const linux = os.linux; const linux = os.linux;
const testing = std.testing; const testing = std.testing;
@ -3730,8 +3731,8 @@ const SocketTestHarness = struct {
client: os.socket_t, client: os.socket_t,
fn close(self: SocketTestHarness) void { fn close(self: SocketTestHarness) void {
os.closeSocket(self.client); posix.close(self.client);
os.closeSocket(self.listener); posix.close(self.listener);
} }
}; };
@ -3739,7 +3740,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
// Create a TCP server socket // Create a TCP server socket
var address = try net.Address.parseIp4("127.0.0.1", 0); var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address); const listener_socket = try createListenerSocket(&address);
errdefer os.closeSocket(listener_socket); errdefer posix.close(listener_socket);
// Submit 1 accept // Submit 1 accept
var accept_addr: os.sockaddr = undefined; var accept_addr: os.sockaddr = undefined;
@ -3748,7 +3749,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
// Create a TCP client socket // Create a TCP client socket
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
errdefer os.closeSocket(client); errdefer posix.close(client);
_ = try ring.connect(0xcccccccc, client, &address.any, address.getOsSockLen()); _ = try ring.connect(0xcccccccc, client, &address.any, address.getOsSockLen());
try testing.expectEqual(@as(u32, 2), try ring.submit()); try testing.expectEqual(@as(u32, 2), try ring.submit());
@ -3788,7 +3789,7 @@ fn createSocketTestHarness(ring: *IO_Uring) !SocketTestHarness {
fn createListenerSocket(address: *net.Address) !os.socket_t { fn createListenerSocket(address: *net.Address) !os.socket_t {
const kernel_backlog = 1; const kernel_backlog = 1;
const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const listener_socket = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
errdefer os.closeSocket(listener_socket); errdefer posix.close(listener_socket);
try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1))); try os.setsockopt(listener_socket, os.SOL.SOCKET, os.SO.REUSEADDR, &mem.toBytes(@as(c_int, 1)));
try os.bind(listener_socket, &address.any, address.getOsSockLen()); try os.bind(listener_socket, &address.any, address.getOsSockLen());
@ -3813,7 +3814,7 @@ test "accept multishot" {
var address = try net.Address.parseIp4("127.0.0.1", 0); var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address); const listener_socket = try createListenerSocket(&address);
defer os.closeSocket(listener_socket); defer posix.close(listener_socket);
// submit multishot accept operation // submit multishot accept operation
var addr: os.sockaddr = undefined; var addr: os.sockaddr = undefined;
@ -3826,7 +3827,7 @@ test "accept multishot" {
while (nr > 0) : (nr -= 1) { while (nr > 0) : (nr -= 1) {
// connect client // connect client
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
errdefer os.closeSocket(client); errdefer posix.close(client);
try os.connect(client, &address.any, address.getOsSockLen()); try os.connect(client, &address.any, address.getOsSockLen());
// test accept completion // test accept completion
@ -3836,7 +3837,7 @@ test "accept multishot" {
try testing.expect(cqe.user_data == userdata); try testing.expect(cqe.user_data == userdata);
try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set try testing.expect(cqe.flags & linux.IORING_CQE_F_MORE > 0); // more flag is set
os.closeSocket(client); posix.close(client);
} }
} }
@ -3909,7 +3910,7 @@ test "accept_direct" {
try ring.register_files(registered_fds[0..]); try ring.register_files(registered_fds[0..]);
const listener_socket = try createListenerSocket(&address); const listener_socket = try createListenerSocket(&address);
defer os.closeSocket(listener_socket); defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa; const accept_userdata: u64 = 0xaaaaaaaa;
const read_userdata: u64 = 0xbbbbbbbb; const read_userdata: u64 = 0xbbbbbbbb;
@ -3927,7 +3928,7 @@ test "accept_direct" {
// connect // connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen()); try os.connect(client, &address.any, address.getOsSockLen());
defer os.closeSocket(client); defer posix.close(client);
// accept completion // accept completion
const cqe_accept = try ring.copy_cqe(); const cqe_accept = try ring.copy_cqe();
@ -3961,7 +3962,7 @@ test "accept_direct" {
// connect // connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen()); try os.connect(client, &address.any, address.getOsSockLen());
defer os.closeSocket(client); defer posix.close(client);
// completion with error // completion with error
const cqe_accept = try ring.copy_cqe(); const cqe_accept = try ring.copy_cqe();
try testing.expect(cqe_accept.user_data == accept_userdata); try testing.expect(cqe_accept.user_data == accept_userdata);
@ -3989,7 +3990,7 @@ test "accept_multishot_direct" {
try ring.register_files(registered_fds[0..]); try ring.register_files(registered_fds[0..]);
const listener_socket = try createListenerSocket(&address); const listener_socket = try createListenerSocket(&address);
defer os.closeSocket(listener_socket); defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa; const accept_userdata: u64 = 0xaaaaaaaa;
@ -4003,7 +4004,7 @@ test "accept_multishot_direct" {
// connect // connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen()); try os.connect(client, &address.any, address.getOsSockLen());
defer os.closeSocket(client); defer posix.close(client);
// accept completion // accept completion
const cqe_accept = try ring.copy_cqe(); const cqe_accept = try ring.copy_cqe();
@ -4018,7 +4019,7 @@ test "accept_multishot_direct" {
// connect // connect
const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0); const client = try os.socket(address.any.family, os.SOCK.STREAM | os.SOCK.CLOEXEC, 0);
try os.connect(client, &address.any, address.getOsSockLen()); try os.connect(client, &address.any, address.getOsSockLen());
defer os.closeSocket(client); defer posix.close(client);
// completion with error // completion with error
const cqe_accept = try ring.copy_cqe(); const cqe_accept = try ring.copy_cqe();
try testing.expect(cqe_accept.user_data == accept_userdata); try testing.expect(cqe_accept.user_data == accept_userdata);
@ -4092,7 +4093,7 @@ test "socket_direct/socket_direct_alloc/close_direct" {
// use sockets from registered_fds in connect operation // use sockets from registered_fds in connect operation
var address = try net.Address.parseIp4("127.0.0.1", 0); var address = try net.Address.parseIp4("127.0.0.1", 0);
const listener_socket = try createListenerSocket(&address); const listener_socket = try createListenerSocket(&address);
defer os.closeSocket(listener_socket); defer posix.close(listener_socket);
const accept_userdata: u64 = 0xaaaaaaaa; const accept_userdata: u64 = 0xaaaaaaaa;
const connect_userdata: u64 = 0xbbbbbbbb; const connect_userdata: u64 = 0xbbbbbbbb;
const close_userdata: u64 = 0xcccccccc; const close_userdata: u64 = 0xcccccccc;

View File

@ -817,7 +817,7 @@ test "shutdown socket" {
error.SocketNotConnected => {}, error.SocketNotConnected => {},
else => |e| return e, else => |e| return e,
}; };
os.closeSocket(sock); std.net.Stream.close(.{ .handle = sock });
} }
test "sigaction" { test "sigaction" {

View File

@ -3322,13 +3322,13 @@ fn buildOutputType(
.ip4 => |ip4_addr| { .ip4 => |ip4_addr| {
if (build_options.only_core_functionality) unreachable; if (build_options.only_core_functionality) unreachable;
var server = std.net.StreamServer.init(.{ const addr: std.net.Address = .{ .in = ip4_addr };
var server = try addr.listen(.{
.reuse_address = true, .reuse_address = true,
}); });
defer server.deinit(); defer server.deinit();
try server.listen(.{ .in = ip4_addr });
const conn = try server.accept(); const conn = try server.accept();
defer conn.stream.close(); defer conn.stream.close();

View File

@ -1,8 +1,6 @@
const std = @import("std"); const std = @import("std");
const http = std.http; const http = std.http;
const Server = http.Server;
const Client = http.Client;
const mem = std.mem; const mem = std.mem;
const testing = std.testing; const testing = std.testing;
@ -19,9 +17,7 @@ var gpa_client = std.heap.GeneralPurposeAllocator(.{ .stack_trace_frames = 12 })
const salloc = gpa_server.allocator(); const salloc = gpa_server.allocator();
const calloc = gpa_client.allocator(); const calloc = gpa_client.allocator();
var server: Server = undefined; fn handleRequest(res: *http.Server, listen_port: u16) !void {
fn handleRequest(res: *Server.Response) !void {
const log = std.log.scoped(.server); const log = std.log.scoped(.server);
log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target }); log.info("{} {s} {s}", .{ res.request.method, @tagName(res.request.version), res.request.target });
@ -125,7 +121,9 @@ fn handleRequest(res: *Server.Response) !void {
} else if (mem.eql(u8, res.request.target, "/redirect/3")) { } else if (mem.eql(u8, res.request.target, "/redirect/3")) {
res.transfer_encoding = .chunked; res.transfer_encoding = .chunked;
const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{server.socket.listen_address.getPort()}); const location = try std.fmt.allocPrint(salloc, "http://127.0.0.1:{d}/redirect/2", .{
listen_port,
});
defer salloc.free(location); defer salloc.free(location);
res.status = .found; res.status = .found;
@ -168,14 +166,15 @@ fn handleRequest(res: *Server.Response) !void {
var handle_new_requests = true; var handle_new_requests = true;
fn runServer(srv: *Server) !void { fn runServer(server: *std.net.Server) !void {
var client_header_buffer: [1024]u8 = undefined; var client_header_buffer: [1024]u8 = undefined;
outer: while (handle_new_requests) { outer: while (handle_new_requests) {
var res = try srv.accept(.{ var connection = try server.accept();
.allocator = salloc, defer connection.stream.close();
var res = http.Server.init(connection, .{
.client_header_buffer = &client_header_buffer, .client_header_buffer = &client_header_buffer,
}); });
defer res.deinit();
while (res.reset() != .closing) { while (res.reset() != .closing) {
res.wait() catch |err| switch (err) { res.wait() catch |err| switch (err) {
@ -184,16 +183,15 @@ fn runServer(srv: *Server) !void {
else => return err, else => return err,
}; };
try handleRequest(&res); try handleRequest(&res, server.listen_address.getPort());
} }
} }
} }
fn serverThread(srv: *Server) void { fn serverThread(server: *std.net.Server) void {
defer srv.deinit();
defer _ = gpa_server.deinit(); defer _ = gpa_server.deinit();
runServer(srv) catch |err| { runServer(server) catch |err| {
std.debug.print("server error: {}\n", .{err}); std.debug.print("server error: {}\n", .{err});
if (@errorReturnTrace()) |trace| { if (@errorReturnTrace()) |trace| {
@ -205,18 +203,10 @@ fn serverThread(srv: *Server) void {
}; };
} }
fn killServer(addr: std.net.Address) void {
handle_new_requests = false;
const conn = std.net.tcpConnectToAddress(addr) catch return;
conn.close();
}
fn getUnusedTcpPort() !u16 { fn getUnusedTcpPort() !u16 {
const addr = try std.net.Address.parseIp("127.0.0.1", 0); const addr = try std.net.Address.parseIp("127.0.0.1", 0);
var s = std.net.StreamServer.init(.{}); var s = try addr.listen(.{});
defer s.deinit(); defer s.deinit();
try s.listen(addr);
return s.listen_address.in.getPort(); return s.listen_address.in.getPort();
} }
@ -225,16 +215,15 @@ pub fn main() !void {
defer _ = gpa_client.deinit(); defer _ = gpa_client.deinit();
server = Server.init(.{ .reuse_address = true });
const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable; const addr = std.net.Address.parseIp("127.0.0.1", 0) catch unreachable;
try server.listen(addr); var server = try addr.listen(.{ .reuse_address = true });
defer server.deinit();
const port = server.socket.listen_address.getPort(); const port = server.listen_address.getPort();
const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server}); const server_thread = try std.Thread.spawn(.{}, serverThread, .{&server});
var client = Client{ .allocator = calloc }; var client: http.Client = .{ .allocator = calloc };
errdefer client.deinit(); errdefer client.deinit();
// defer client.deinit(); handled below // defer client.deinit(); handled below
@ -691,6 +680,12 @@ pub fn main() !void {
client.deinit(); client.deinit();
killServer(server.socket.listen_address); {
handle_new_requests = false;
const conn = std.net.tcpConnectToAddress(server.listen_address) catch return;
conn.close();
}
server_thread.join(); server_thread.join();
} }