diff --git a/backend/core/src/Core.API/Program.cs b/backend/core/src/Core.API/Program.cs index a8a30e5..76ea524 100644 --- a/backend/core/src/Core.API/Program.cs +++ b/backend/core/src/Core.API/Program.cs @@ -47,6 +47,8 @@ app.UseRouting(); app.UseAuthorization(); +app.UseMiddleware(); + app.MapControllers(); diff --git a/backend/core/src/Core.API/ResponseHandling/PublicKeyLinkedMiddleware.cs b/backend/core/src/Core.API/ResponseHandling/PublicKeyLinkedMiddleware.cs new file mode 100644 index 0000000..d327f07 --- /dev/null +++ b/backend/core/src/Core.API/ResponseHandling/PublicKeyLinkedMiddleware.cs @@ -0,0 +1,82 @@ +using Core.Domain.Repositories; +using System.Net; +using System.Security.Claims; +using System.Security.Principal; + +namespace Core.API.ResponseHandling; + +/// +/// Verifies if the public key supplied in the header is linked to the user. +/// The user is determined by the NameIdentifier claim (CustomerCode) in the request. +/// +public class PublicKeyLinkedMiddleware +{ + private readonly ICustomerDeviceRepository _customerDeviceRepository; + private readonly RequestDelegate _next; + + public PublicKeyLinkedMiddleware( + ICustomerDeviceRepository customerDeviceRepository, + RequestDelegate next) + { + _customerDeviceRepository = customerDeviceRepository; + _next = next; + } + + public async Task Invoke(HttpContext context) + { + // skip the middleware for specific endpoints + if (!SkipEndpoint(context)) + { + var userId = GetUserId(context.User); + var pubKey = context.Request.Headers["x-public-key"]; + + // get all public keys linked to the user + var customerDevices = await _customerDeviceRepository.GetAsync(userId, context.RequestAborted); + + // if the user does not have any public keys or the public key is not linked to the user, return forbidden + if (customerDevices is null + || customerDevices.PublicKeys.All(keys => keys.PublicKey != pubKey)) + { + // TODO: make this a custom error + + context.Response.StatusCode = (int)HttpStatusCode.Forbidden; + await context.Response.WriteAsync("Public key not linked to user"); + return; + } + } + + // safe to continue + await _next(context); + } + + /// + /// Extract the user id from the claims + /// + private static string GetUserId(IPrincipal user) + { + if (user.Identity is not ClaimsIdentity identity) + { + throw new Exception($"Identity is not of type ClaimsIdentity"); + } + + var claim = identity.FindFirst(ClaimTypes.NameIdentifier); + + return claim is null + ? throw new Exception($"{ClaimTypes.NameIdentifier} not found in claims") + : claim.Value; + } + + /// + /// Skip the middleware for specific endpoints + /// + private static bool SkipEndpoint(HttpContext context) + { + var endpoint = context.GetEndpoint(); + var endpointName = endpoint?.Metadata.GetMetadata()?.EndpointName; + + var excludeList = new[] { "DeviceAuthentication" }; + + return context.Request.Path.StartsWithSegments("/health") + || excludeList.Contains(endpointName); + } +} \ No newline at end of file diff --git a/backend/core/tests/Core.APITests/FakeCustomerDeviceRepository.cs b/backend/core/tests/Core.APITests/FakeCustomerDeviceRepository.cs new file mode 100644 index 0000000..7179ac0 --- /dev/null +++ b/backend/core/tests/Core.APITests/FakeCustomerDeviceRepository.cs @@ -0,0 +1,51 @@ +using Core.Domain.Abstractions; +using Core.Domain.Entities.CustomerAggregate; +using Core.Domain.Repositories; + +namespace Core.APITests; + +public class FakeCustomerDeviceRepository : ICustomerDeviceRepository +{ + public void Add(CustomerOTPKeyStore entity) + { + throw new NotImplementedException(); + } + + public Task FindAsync(int id, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task GetAsync(string customerCode, CancellationToken cancellationToken = default) + { + var store = new CustomerOTPKeyStore + { + CustomerCode = customerCode, + PublicKeys = new[] { new CustomerDevicePublicKeys { PublicKey = "VALID-PUBKEY" } } + }; + + return Task.FromResult(store)!; + } + + public Task HasOtpKeyAsync(string customerCode, CancellationToken cancellationToken = default) + { + return Task.FromResult(false); + } + + public Task HasPublicKeyAsync(string customerCode, string publicKey, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Task> ListAsync(ISpecification specification, + CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public void Update(CustomerOTPKeyStore entity) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/backend/core/tests/Core.APITests/ResponseHandlingTests/ExceptionMiddlewareTests.cs b/backend/core/tests/Core.APITests/ResponseHandlingTests/ExceptionMiddlewareTests.cs index fcbfcce..3223e9a 100644 --- a/backend/core/tests/Core.APITests/ResponseHandlingTests/ExceptionMiddlewareTests.cs +++ b/backend/core/tests/Core.APITests/ResponseHandlingTests/ExceptionMiddlewareTests.cs @@ -2,13 +2,12 @@ // under the Apache License, Version 2.0. See the NOTICE file at the root // of this distribution or at http://www.apache.org/licenses/LICENSE-2.0 -using Core.API.ResponseHandling; +using Core.API.ResponseHandling; using Core.Domain; using Core.Domain.Exceptions; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.TestHost; -using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using System.Net; @@ -33,10 +32,7 @@ public async Task ExistingKey_Returns_Conflict() { app.ConfigureCustomExceptionMiddleware(); - app.Run(context => - { - throw new CustomErrorsException(DomainErrorCode.ExistingKeyError.ToString(), "TEST", "Verification needed."); - }); + app.Run(context => throw new CustomErrorsException(DomainErrorCode.ExistingKeyError.ToString(), "TEST", "Verification needed.")); }); }) .StartAsync(); diff --git a/backend/core/tests/Core.APITests/ResponseHandlingTests/PublicKeyLinkedMiddlewareTests.cs b/backend/core/tests/Core.APITests/ResponseHandlingTests/PublicKeyLinkedMiddlewareTests.cs new file mode 100644 index 0000000..c13f11f --- /dev/null +++ b/backend/core/tests/Core.APITests/ResponseHandlingTests/PublicKeyLinkedMiddlewareTests.cs @@ -0,0 +1,118 @@ +// Copyright 2023 Quantoz Technology B.V. and contributors. Licensed +// under the Apache License, Version 2.0. See the NOTICE file at the root +// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0 + +using Core.API.ResponseHandling; +using Core.Domain.Repositories; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using System.Net; +using System.Net.Http.Headers; +using Microsoft.AspNetCore.Authentication; + +namespace Core.APITests.ResponseHandlingTests; + +[TestClass] +public class PublicKeyLinkedMiddlewareTests : IDisposable +{ + private IHost host = default!; + private HttpClient client = default!; + + [TestInitialize] + public async Task Init() + { + host = await new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + services.AddTransient(); + + // add claims to user identity - this is the user that will be authenticated + services.AddAuthentication("Test") + .AddScheme("Test", op => { }); + services.AddAuthorization(); + }) + .Configure(app => + { + app.UseRouting(); + app.UseAuthentication(); + + // add the middleware after authentication to have access to the user object and its claims + app.UseMiddleware(); + app.UseAuthorization(); + + app.UseEndpoints(endpoints => + { + // map an endpoint that verifies the public key with the authenticated user + endpoints.Map("/verify", async context => + { + await context.Response.WriteAsync($"Hello World {context.User.Identity.Name}!"); + }) + .RequireAuthorization() + .WithName("NotInTheIgnoreList"); + + // map an endpoint that ignores the public key verification (hardcoded list) through the name + endpoints.Map("/ignore", async context => + { + await context.Response.WriteAsync($"Hello World {context.User.Identity.Name}!"); + }) + .RequireAuthorization() + .WithName("DeviceAuthentication"); + }); + }); + }) + .StartAsync(); + + client = host.GetTestClient(); + client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Test"); + } + + void IDisposable.Dispose() + { + host.Dispose(); + client.Dispose(); + } + + [TestMethod] + public async Task Verify_Returns_Forbidden() + { + var response = await client.GetAsync("/verify"); + + var responseString = await response.Content.ReadAsStringAsync(); + Assert.AreEqual("Public key not linked to user", + responseString); + Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode); + } + + [TestMethod] + public async Task Verify_Returns_Ok() + { + var request = new HttpRequestMessage(HttpMethod.Get, "/verify"); + request.Headers.Add("x-public-key", "VALID-PUBKEY"); + var response = await client.GetAsync("/verify"); + + var responseString = await response.Content.ReadAsStringAsync(); + Assert.AreEqual("Public key not linked to user", + responseString); + Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode); + } + + [TestMethod] + public async Task IgnoreEndpoint_Returns_Ok() + { + // send request + var response = await client.GetAsync("/ignore"); + + var responseString = await response.Content.ReadAsStringAsync(); + Assert.AreEqual($"Hello World TestUser!", responseString); + Assert.AreEqual(HttpStatusCode.OK, response.StatusCode); + } +} \ No newline at end of file diff --git a/backend/core/tests/Core.APITests/TestAuthHandler.cs b/backend/core/tests/Core.APITests/TestAuthHandler.cs new file mode 100644 index 0000000..d70e467 --- /dev/null +++ b/backend/core/tests/Core.APITests/TestAuthHandler.cs @@ -0,0 +1,31 @@ +using System.Security.Claims; +using System.Text.Encodings.Web; +using Microsoft.AspNetCore.Authentication; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Core.APITests.ResponseHandlingTests; + +public sealed class TestAuthHandler : AuthenticationHandler +{ + public TestAuthHandler(IOptionsMonitor options, ILoggerFactory logger, + UrlEncoder encoder, ISystemClock clock) : base(options, logger, encoder, clock) + { + } + + protected override Task HandleAuthenticateAsync() + { + var claims = new List + { + new(ClaimTypes.Name, "TestUser"), + new(ClaimTypes.NameIdentifier, "TestUser") + }; + var identity = new ClaimsIdentity(claims, "Test"); + var principal = new ClaimsPrincipal(identity); + var ticket = new AuthenticationTicket(principal, "Test"); + + var result = AuthenticateResult.Success(ticket); + + return Task.FromResult(result); + } +} \ No newline at end of file