diff --git a/Backend/Remora.Discord.Gateway/DiscordGatewayClient.cs b/Backend/Remora.Discord.Gateway/DiscordGatewayClient.cs index d7b2506184..15ddcd1d20 100644 --- a/Backend/Remora.Discord.Gateway/DiscordGatewayClient.cs +++ b/Backend/Remora.Discord.Gateway/DiscordGatewayClient.cs @@ -566,6 +566,18 @@ private async Task RunConnectionIterationAsync(CancellationToken stopReq if (receiveHello.Entity is not IPayload hello) { + if (receiveHello.Entity is IPayload) + { + // Discord may spit out a reconnect if the node is we're connecting while the gateway node is + // shutting down, but before the node is labeled as unavailable. + return new GatewayError + ( + "The gateway requested a reconnect.", + false, + false + ); + } + // Not receiving a hello is a non-recoverable error return new GatewayError ( diff --git a/Tests/Remora.Discord.Gateway.Tests/Tests/DiscordGatewayClientTests.cs b/Tests/Remora.Discord.Gateway.Tests/Tests/DiscordGatewayClientTests.cs index c0e0e6ca0a..4a35400944 100644 --- a/Tests/Remora.Discord.Gateway.Tests/Tests/DiscordGatewayClientTests.cs +++ b/Tests/Remora.Discord.Gateway.Tests/Tests/DiscordGatewayClientTests.cs @@ -711,6 +711,76 @@ public async Task CanReconnectAfterExceptionAsync() ResultAssert.Successful(runResult); } + /// + /// Tests whether the client can reconnect after being sent a Reconnect instead of Hello. + /// + /// A representing the asynchronous unit test. + [Fact] + public async Task CanReconnectAfterReconnectInsteadOfHelloAsync() + { + var tokenSource = new CancellationTokenSource(); + var transportMock = new MockedTransportServiceBuilder(_testOutput) + .WithTimeout(TimeSpan.FromSeconds(30)) + .Sequence + ( + s => s + .ExpectConnection + ( + new Uri($"wss://gateway.discord.gg/?v={(int)DiscordAPIVersion.V10}&encoding=json") + ) + .Send(new Reconnect()) + .ExpectDisconnect() + .ExpectConnection + ( + new Uri($"wss://gateway.discord.gg/?v={(int)DiscordAPIVersion.V10}&encoding=json") + ) + .Send(new Hello(TimeSpan.FromMilliseconds(200))) + .Expect + ( + i => + { + Assert.Equal(Constants.MockToken, i?.Token); + return true; + } + ) + .Send + ( + new Ready + ( + 8, + Constants.BotUser, + new List(), + Constants.MockSessionID, + Constants.MockResumeGatewayUrl, + default, + new PartialApplication() + ) + ) + ) + .Continuously + ( + c => c + .Expect() + .Send() + ) + .Finish(tokenSource) + .Build(); + + var transportMockDescriptor = ServiceDescriptor.Singleton(typeof(IPayloadTransportService), transportMock); + + var services = new ServiceCollection() + .AddDiscordGateway(_ => Constants.MockToken) + .Replace(transportMockDescriptor) + .Replace(CreateMockedGatewayAPI()) + .AddSingleton() + .BuildServiceProvider(true); + + var client = services.GetRequiredService(); + var runResult = await client.RunAsync(tokenSource.Token); + + ResultAssert.Successful(runResult); + } + private static ServiceDescriptor CreateMockedGatewayAPI() { var gatewayAPIMock = new Mock();