diff --git a/Tests/Services/BanManagerTest.cs b/Tests/Services/BanManagerTest.cs index 42b38a1..5d7ec03 100644 --- a/Tests/Services/BanManagerTest.cs +++ b/Tests/Services/BanManagerTest.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Net; -using System.Threading.Tasks; +using System.Threading; using Fail2Ban4Win.Config; using Fail2Ban4Win.Facades; using Fail2Ban4Win.Services; @@ -42,8 +42,8 @@ public class BanManagerTest: IDisposable { neverBanSubnets = new[] { IPNetwork.Parse("73.202.12.148/32") } }; - private readonly IFirewallWASRulesCollection firewallRules = new FakeFirewallRulesCollection(); - private readonly FirewallFacade firewallFacade = A.Fake(); + private readonly FakeFirewallRulesCollection firewallRules = new(); + private readonly FirewallFacade firewallFacade = A.Fake(); public BanManagerTest(ITestOutputHelper testOutput) { this.testOutput = testOutput; @@ -166,12 +166,15 @@ public void dontBanInDryRunMode() { } [Fact] - public async Task deleteExistingRulesOnStartup() { + public void deleteExistingRulesOnStartup() { banManager.Dispose(); firewallRules.Add(new FirewallWASRule("deleteme1", FirewallAction.Block, FirewallDirection.Inbound, FirewallProfiles.Public) { Grouping = "Fail2Ban4Win" }); firewallRules.Add(new FirewallWASRule("deleteme2", FirewallAction.Block, FirewallDirection.Inbound, FirewallProfiles.Public) { Grouping = "Fail2Ban4Win" }); + CountdownEvent rulesRemoved = new(firewallRules.Count); + firewallRules.ruleRemoved += (_, _) => rulesRemoved.Signal(); + Assert.NotEmpty(firewallRules); BanManagerImpl manager = new(eventLogListener, configuration, firewallFacade); @@ -179,7 +182,7 @@ public async Task deleteExistingRulesOnStartup() { Assert.NotEmpty(firewallRules); //deletion runs asynchronously to speed up startup - await Task.Delay(100); + rulesRemoved.Wait(TimeSpan.FromSeconds(10)); Assert.Empty(firewallRules); @@ -187,12 +190,15 @@ public async Task deleteExistingRulesOnStartup() { } [Fact] - public async Task unbanAfterBanExpired() { - IEnumerable sourceAddresses = new[] { + public void unbanAfterBanExpired() { + ICollection sourceAddresses = new[] { IPAddress.Parse("198.51.100.1"), IPAddress.Parse("101.206.243.0") }; + CountdownEvent rulesRemoved = new(sourceAddresses.Count); + firewallRules.ruleRemoved += (_, _) => rulesRemoved.Signal(); + foreach (IPAddress sourceAddress in sourceAddresses) { for (int i = 0; i < MAX_ALLOWED_FAILURES + 1; i++) { eventLogListener.failure += Raise.With(null, sourceAddress); @@ -201,26 +207,31 @@ public async Task unbanAfterBanExpired() { Assert.NotEmpty(firewallRules); - await Task.Delay((int) configuration.banPeriod.TotalMilliseconds * 4); + rulesRemoved.Wait(TimeSpan.FromSeconds(10)); testOutput.WriteLine("banPeriod = {0}", configuration.banPeriod); Assert.Empty(firewallRules); } [Fact] - public async Task unbanCatchesAndLogsExceptions() { + public void unbanCatchesAndLogsExceptions() { + IPAddress sourceAddress = IPAddress.Parse("103.153.254.0"); + CountdownEvent rulesRemoved = new(1); + var throwingFirewallRules = A.Fake>(options => options.Wrapping(firewallRules)); - A.CallTo(() => throwingFirewallRules.Remove(A._)).Throws(); + A.CallTo(() => throwingFirewallRules.Remove(A._)).Throws(() => { + rulesRemoved.Signal(); + throw new InvalidOperationException("This is intentionally thrown as part of a test"); + }); A.CallTo(() => firewallFacade.Rules).Returns(throwingFirewallRules); - IPAddress sourceAddress = IPAddress.Parse("103.153.254.0"); for (int i = 0; i < MAX_ALLOWED_FAILURES + 1; i++) { eventLogListener.failure += Raise.With(null, sourceAddress); } Assert.NotEmpty(firewallRules); - await Task.Delay((int) configuration.banPeriod.TotalMilliseconds * 2); + rulesRemoved.Wait(TimeSpan.FromSeconds(10)); testOutput.WriteLine("banPeriod = {0}", configuration.banPeriod); Assert.NotEmpty(firewallRules); @@ -275,16 +286,22 @@ private class FakeFirewallRulesCollection: List, IFirewallWASRu private readonly object mutex = new(); public FirewallWASRule? this[string name] => this.FirstOrDefault(rule => rule.Name == name); + public event EventHandler? ruleRemoved; public bool Remove(string name) { + FirewallWASRule? ruleToRemove; + bool result; + lock (mutex) { - if (this[name] is { } ruleToRemove) { - Remove(ruleToRemove); - return true; - } else { - return false; - } + ruleToRemove = this[name]; + result = ruleToRemove is not null && Remove(ruleToRemove); + } + + if (result) { + ruleRemoved?.Invoke(this, ruleToRemove!); } + + return result; } void ICollection.Add(FirewallWASRule item) { @@ -300,9 +317,17 @@ void ICollection.Clear() { } bool ICollection.Remove(FirewallWASRule item) { + bool result; + lock (mutex) { - return Remove(item); + result = Remove(item); } + + if (result) { + ruleRemoved?.Invoke(this, item); + } + + return result; } }