Skip to content

Commit

Permalink
Add middleware to verify pubkey is linked to customer
Browse files Browse the repository at this point in the history
Add various tests
  • Loading branch information
raymens committed Feb 16, 2024
1 parent b59d156 commit 14a366f
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 6 deletions.
2 changes: 2 additions & 0 deletions backend/core/src/Core.API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
app.UseRouting();
app.UseAuthorization();

app.UseMiddleware<PublicKeyLinkedMiddleware>();

app.MapControllers();


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
using Core.Domain.Repositories;
using System.Net;
using System.Security.Claims;
using System.Security.Principal;

namespace Core.API.ResponseHandling;

/// <summary>
/// Verifies if the public key supplied in the header is linked to the user.
/// The user is determined by the NameIdentifier claim (CustomerCode) in the request.
/// </summary>
public class PublicKeyLinkedMiddleware
{
private readonly ICustomerDeviceRepository _customerDeviceRepository;
private readonly RequestDelegate _next;

public PublicKeyLinkedMiddleware(
ICustomerDeviceRepository customerDeviceRepository,
RequestDelegate next)
{
_customerDeviceRepository = customerDeviceRepository;
_next = next;
}

public async Task Invoke(HttpContext context)
{
// skip the middleware for specific endpoints
if (!SkipEndpoint(context))
{
var userId = GetUserId(context.User);
var pubKey = context.Request.Headers["x-public-key"];

// get all public keys linked to the user
var customerDevices = await _customerDeviceRepository.GetAsync(userId, context.RequestAborted);

// if the user does not have any public keys or the public key is not linked to the user, return forbidden
if (customerDevices is null
|| customerDevices.PublicKeys.All(keys => keys.PublicKey != pubKey))
{
// TODO: make this a custom error

context.Response.StatusCode = (int)HttpStatusCode.Forbidden;
await context.Response.WriteAsync("Public key not linked to user");
return;
}
}

// safe to continue
await _next(context);
}

/// <summary>
/// Extract the user id from the claims
/// </summary>
private static string GetUserId(IPrincipal user)
{
if (user.Identity is not ClaimsIdentity identity)
{
throw new Exception($"Identity is not of type ClaimsIdentity");
}

var claim = identity.FindFirst(ClaimTypes.NameIdentifier);

return claim is null
? throw new Exception($"{ClaimTypes.NameIdentifier} not found in claims")
: claim.Value;
}

/// <summary>
/// Skip the middleware for specific endpoints
/// </summary>
private static bool SkipEndpoint(HttpContext context)
{
var endpoint = context.GetEndpoint();
var endpointName = endpoint?.Metadata.GetMetadata<EndpointNameMetadata>()?.EndpointName;

var excludeList = new[] { "DeviceAuthentication" };

return context.Request.Path.StartsWithSegments("/health")
|| excludeList.Contains(endpointName);
}
}
51 changes: 51 additions & 0 deletions backend/core/tests/Core.APITests/FakeCustomerDeviceRepository.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Core.Domain.Abstractions;
using Core.Domain.Entities.CustomerAggregate;
using Core.Domain.Repositories;

namespace Core.APITests;

public class FakeCustomerDeviceRepository : ICustomerDeviceRepository
{
public void Add(CustomerOTPKeyStore entity)
{
throw new NotImplementedException();
}

public Task<CustomerOTPKeyStore> FindAsync(int id, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

public Task<CustomerOTPKeyStore?> GetAsync(string customerCode, CancellationToken cancellationToken = default)
{
var store = new CustomerOTPKeyStore
{
CustomerCode = customerCode,
PublicKeys = new[] { new CustomerDevicePublicKeys { PublicKey = "VALID-PUBKEY" } }
};

return Task.FromResult(store)!;
}

public Task<bool> HasOtpKeyAsync(string customerCode, CancellationToken cancellationToken = default)
{
return Task.FromResult(false);
}

public Task<bool> HasPublicKeyAsync(string customerCode, string publicKey,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

public Task<List<CustomerOTPKeyStore>> ListAsync(ISpecification<CustomerOTPKeyStore> specification,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}

public void Update(CustomerOTPKeyStore entity)
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
// under the Apache License, Version 2.0. See the NOTICE file at the root
// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0

using Core.API.ResponseHandling;
using Core.API.ResponseHandling;
using Core.Domain;
using Core.Domain.Exceptions;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using System.Net;

Expand All @@ -33,10 +32,7 @@ public async Task ExistingKey_Returns_Conflict()
{
app.ConfigureCustomExceptionMiddleware();

app.Run(context =>
{
throw new CustomErrorsException(DomainErrorCode.ExistingKeyError.ToString(), "TEST", "Verification needed.");
});
app.Run(context => throw new CustomErrorsException(DomainErrorCode.ExistingKeyError.ToString(), "TEST", "Verification needed."));
});
})
.StartAsync();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright 2023 Quantoz Technology B.V. and contributors. Licensed
// under the Apache License, Version 2.0. See the NOTICE file at the root
// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0

using Core.API.ResponseHandling;
using Core.Domain.Repositories;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using System.Net;
using System.Net.Http.Headers;
using Microsoft.AspNetCore.Authentication;

namespace Core.APITests.ResponseHandlingTests;

[TestClass]
public class PublicKeyLinkedMiddlewareTests : IDisposable
{
private IHost host = default!;
private HttpClient client = default!;

[TestInitialize]
public async Task Init()
{
host = await new HostBuilder()
.ConfigureWebHost(webBuilder =>
{
webBuilder
.UseTestServer()
.ConfigureServices(services =>
{
services.AddRouting();
services.AddTransient<ICustomerDeviceRepository, FakeCustomerDeviceRepository>();

// add claims to user identity - this is the user that will be authenticated
services.AddAuthentication("Test")
.AddScheme<AuthenticationSchemeOptions, TestAuthHandler>("Test", op => { });
services.AddAuthorization();
})
.Configure(app =>
{
app.UseRouting();
app.UseAuthentication();

// add the middleware after authentication to have access to the user object and its claims
app.UseMiddleware<PublicKeyLinkedMiddleware>();
app.UseAuthorization();

app.UseEndpoints(endpoints =>
{
// map an endpoint that verifies the public key with the authenticated user
endpoints.Map("/verify", async context =>
{
await context.Response.WriteAsync($"Hello World {context.User.Identity.Name}!");
})
.RequireAuthorization()
.WithName("NotInTheIgnoreList");

// map an endpoint that ignores the public key verification (hardcoded list) through the name
endpoints.Map("/ignore", async context =>
{
await context.Response.WriteAsync($"Hello World {context.User.Identity.Name}!");
})
.RequireAuthorization()
.WithName("DeviceAuthentication");
});
});
})
.StartAsync();

client = host.GetTestClient();
client.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Test");
}

void IDisposable.Dispose()
{
host.Dispose();
client.Dispose();
}

[TestMethod]
public async Task Verify_Returns_Forbidden()
{
var response = await client.GetAsync("/verify");

var responseString = await response.Content.ReadAsStringAsync();
Assert.AreEqual("Public key not linked to user",
responseString);
Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode);
}

[TestMethod]
public async Task Verify_Returns_Ok()
{
var request = new HttpRequestMessage(HttpMethod.Get, "/verify");
request.Headers.Add("x-public-key", "VALID-PUBKEY");
var response = await client.GetAsync("/verify");

var responseString = await response.Content.ReadAsStringAsync();
Assert.AreEqual("Public key not linked to user",
responseString);
Assert.AreEqual(HttpStatusCode.Forbidden, response.StatusCode);
}

[TestMethod]
public async Task IgnoreEndpoint_Returns_Ok()
{
// send request
var response = await client.GetAsync("/ignore");

var responseString = await response.Content.ReadAsStringAsync();
Assert.AreEqual($"Hello World TestUser!", responseString);
Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
}
}
31 changes: 31 additions & 0 deletions backend/core/tests/Core.APITests/TestAuthHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System.Security.Claims;
using System.Text.Encodings.Web;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Core.APITests.ResponseHandlingTests;

public sealed class TestAuthHandler : AuthenticationHandler<AuthenticationSchemeOptions>
{
public TestAuthHandler(IOptionsMonitor<AuthenticationSchemeOptions> options, ILoggerFactory logger,
UrlEncoder encoder, ISystemClock clock) : base(options, logger, encoder, clock)
{
}

protected override Task<AuthenticateResult> HandleAuthenticateAsync()
{
var claims = new List<Claim>
{
new(ClaimTypes.Name, "TestUser"),
new(ClaimTypes.NameIdentifier, "TestUser")
};
var identity = new ClaimsIdentity(claims, "Test");
var principal = new ClaimsPrincipal(identity);
var ticket = new AuthenticationTicket(principal, "Test");

var result = AuthenticateResult.Success(ticket);

return Task.FromResult(result);
}
}

0 comments on commit 14a366f

Please sign in to comment.