Skip to content

Commit

Permalink
fix(tcp_chain): ensure accepting before connect
Browse files Browse the repository at this point in the history
  • Loading branch information
mookums committed Mar 6, 2025
1 parent 56c6e27 commit 7f211e1
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
14 changes: 7 additions & 7 deletions src/aio/apis/poll.zig
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ const RecvError = @import("../completion.zig").RecvError;
const SendResult = @import("../completion.zig").SendResult;
const SendError = @import("../completion.zig").SendError;

const TimerPair = struct {
milliseconds: usize,
task: usize,
};
const TimerPair = struct { milliseconds: usize, task: usize };

const TimerQueue = std.PriorityQueue(TimerPair, void, struct {
fn compare(_: void, a: TimerPair, b: TimerPair) std.math.Order {
Expand Down Expand Up @@ -261,7 +258,6 @@ pub const AsyncPoll = struct {
@panic("failed to get job from fd!");
};

// TODO: add job_complete that allows us to leave jobs that return WouldBlock
poll_result -= 1;
_ = poll.fd_list.swapRemove(i);
assert(poll.fd_job_map.remove(pollfd.fd));
Expand All @@ -277,10 +273,14 @@ pub const AsyncPoll = struct {

// requeue the wake request
if (comptime builtin.os.tag == .windows) {
try poll.fd_list.append(.{ .fd = @ptrCast(poll.wake_pipe[0]), .events = std.posix.POLL.IN, .revents = 0 });
try poll.fd_list.append(
.{ .fd = @ptrCast(poll.wake_pipe[0]), .events = std.posix.POLL.IN, .revents = 0 },
);
try poll.fd_job_map.put(@ptrCast(poll.wake_pipe[0]), .{ .index = 0, .type = .wake, .task = 0 });
} else {
try poll.fd_list.append(.{ .fd = poll.wake_pipe[0], .events = std.posix.POLL.IN, .revents = 0 });
try poll.fd_list.append(
.{ .fd = poll.wake_pipe[0], .events = std.posix.POLL.IN, .revents = 0 },
);
try poll.fd_job_map.put(poll.wake_pipe[0], .{ .index = 0, .type = .wake, .task = 0 });
}

Expand Down
7 changes: 5 additions & 2 deletions test/e2e/second.zig
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ pub fn start_frame(rt: *Runtime, shared_params: *const SharedParams) !void {
server_chain_ptr.* = try TcpServerChain.init(rt.allocator, chain, 4096);
client_chain_ptr.* = try server_chain_ptr.derive_client_chain();

const accepting = try rt.allocator.create(bool);
accepting.* = false;

try rt.spawn(
.{ client_chain_ptr, rt, &tcp_client_chain_count, port },
.{ client_chain_ptr, rt, &tcp_client_chain_count, port, accepting },
TcpClientChain.chain_frame,
STACK_SIZE,
);
try rt.spawn(
.{ server_chain_ptr, rt, &tcp_server_chain_count, socket },
.{ server_chain_ptr, rt, &tcp_server_chain_count, socket, accepting },
TcpServerChain.chain_frame,
STACK_SIZE,
);
Expand Down
13 changes: 10 additions & 3 deletions test/e2e/tcp_chain.zig
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const assert = std.debug.assert;
const Runtime = @import("tardy").Runtime;

const Socket = @import("tardy").Socket;
const Frame = @import("tardy").Frame;

const OpenFileResult = @import("tardy").OpenFileResult;
const ReadResult = @import("tardy").ReadResult;
Expand Down Expand Up @@ -112,15 +113,17 @@ pub const TcpServerChain = struct {
defer self.allocator.free(self.buffer);
}

pub fn chain_frame(chain: *TcpServerChain, rt: *Runtime, counter: *usize, server_socket: Socket) !void {
pub fn chain_frame(chain: *TcpServerChain, rt: *Runtime, counter: *usize, server_socket: Socket, accepting: *bool) !void {
defer rt.allocator.destroy(chain);
defer chain.deinit();
errdefer unreachable;

chain: while (chain.index < chain.steps.len) : (chain.index += 1) {
switch (chain.steps[chain.index]) {
.accept => {
accepting.* = true;
const socket = try server_socket.accept(rt);
accepting.* = false;
chain.socket = socket;
},
.recv => {
Expand All @@ -143,6 +146,7 @@ pub const TcpServerChain = struct {
if (counter.* == 0) {
log.debug("closing main accept socket", .{});
try server_socket.close(rt);
rt.allocator.destroy(accepting);
}
}
};
Expand All @@ -165,7 +169,7 @@ pub const TcpClientChain = struct {
defer self.allocator.free(self.buffer);
}

pub fn chain_frame(chain: *TcpClientChain, rt: *Runtime, counter: *usize, port: u16) !void {
pub fn chain_frame(chain: *TcpClientChain, rt: *Runtime, counter: *usize, port: u16, accepting: *bool) !void {
defer rt.allocator.destroy(chain);
defer chain.deinit();
errdefer unreachable;
Expand All @@ -174,7 +178,10 @@ pub const TcpClientChain = struct {

chain: while (chain.index < chain.steps.len) : (chain.index += 1) {
switch (chain.steps[chain.index]) {
.connect => _ = try socket.connect(rt),
.connect => {
while (!accepting.*) Frame.yield();
_ = try socket.connect(rt);
},
.recv => {
const length = socket.recv(rt, chain.buffer) catch |e| switch (e) {
error.Closed => break :chain,
Expand Down

0 comments on commit 7f211e1

Please sign in to comment.