Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into update-dotnet-8
Browse files Browse the repository at this point in the history
  • Loading branch information
raymens committed Feb 19, 2024
2 parents 8264f1a + d40a89d commit b26e2e0
Show file tree
Hide file tree
Showing 14 changed files with 720 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0

using Core.Application.Pipelines;
using Core.Domain;
using FluentValidation;
using MediatR;

Expand All @@ -19,6 +20,8 @@ public static IServiceCollection AddApplication(this IServiceCollection services
services.AddValidatorsFromAssembly(Application.AssemblyReference.Assembly);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(ApplicationInjection).Assembly));

services.AddSingleton<IDateTimeProvider, SystemDateTimeProvider>();

return services;
}

Expand Down
2 changes: 2 additions & 0 deletions backend/core/src/Core.API/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
app.UseRouting();
app.UseAuthorization();

app.UseMiddleware<PublicKeyLinkedMiddleware>();

app.MapControllers();


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2023 Quantoz Technology B.V. and contributors. Licensed
// under the Apache License, Version 2.0. See the NOTICE file at the root
// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0

using System.Net;
using System.Security.Claims;
using System.Security.Principal;
using Core.Domain.Exceptions;
using Core.Domain.Repositories;
using Core.Presentation.Models;

namespace Core.API.ResponseHandling;

/// <summary>
/// Verifies if the public key supplied in the header is linked to the user.
/// The user is determined by the NameIdentifier claim (CustomerCode) in the request.
/// </summary>
public class PublicKeyLinkedMiddleware
{
private readonly ICustomerDeviceRepository _customerDeviceRepository;
private readonly RequestDelegate _next;
private readonly ILogger<PublicKeyLinkedMiddleware> _logger;

public PublicKeyLinkedMiddleware(
ICustomerDeviceRepository customerDeviceRepository,
RequestDelegate next,
ILogger<PublicKeyLinkedMiddleware> logger)
{
_customerDeviceRepository = customerDeviceRepository;
_next = next;
_logger = logger;
}

public async Task Invoke(HttpContext context)
{
// skip the middleware for specific endpoints
if (!SkipEndpoint(context))
{
var userId = GetUserId(context.User);
var pubKey = context.Request.Headers["x-public-key"];

// get all public keys linked to the user
var customerDevices = await _customerDeviceRepository.GetAsync(userId, context.RequestAborted);

// if the user does not have any public keys or the public key is not linked to the user, return forbidden
if (customerDevices is null
|| customerDevices.PublicKeys.All(keys => keys.PublicKey != pubKey))
{
_logger.LogError("Public key not linked to user");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Unknown public-key", "x-public-key"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}
}

// safe to continue
await _next(context);
}

/// <summary>
/// Extract the user id from the claims
/// </summary>
private static string GetUserId(IPrincipal user)
{
if (user.Identity is not ClaimsIdentity identity)
{
throw new Exception($"Identity is not of type ClaimsIdentity");
}

var claim = identity.FindFirst(ClaimTypes.NameIdentifier);

return claim is null
? throw new Exception($"{ClaimTypes.NameIdentifier} not found in claims")
: claim.Value;
}

/// <summary>
/// Skip the middleware for specific endpoints
/// </summary>
private static bool SkipEndpoint(HttpContext context)
{
var endpoint = context.GetEndpoint();
var endpointName = endpoint?.Metadata.GetMetadata<EndpointNameMetadata>()?.EndpointName;

var excludeList = new[] { "DeviceAuthentication" };

return context.Request.Path.StartsWithSegments("/health")
|| excludeList.Contains(endpointName);
}

private static async Task WriteCustomErrors(HttpResponse httpResponse, CustomErrors customErrors, int statusCode)
{
httpResponse.StatusCode = statusCode;
httpResponse.ContentType = "application/json";

var response = CustomErrorsResponse.FromCustomErrors(customErrors);
var json = System.Text.Json.JsonSerializer.Serialize(response);
await httpResponse.WriteAsync(json);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,134 +2,149 @@
// under the Apache License, Version 2.0. See the NOTICE file at the root
// of this distribution or at http://www.apache.org/licenses/LICENSE-2.0

using Core.Domain;
using Core.Domain.Exceptions;
using Core.Presentation.Models;
using Newtonsoft.Json.Linq;
using NSec.Cryptography;
using System.Net;
using System.Text;
using static Core.Domain.Constants;

namespace Core.API.ResponseHandling
{
/// <summary>
/// Middleware to verify the signature of incoming requests.
/// It reads the x-signature, x-algorithm, x-public-key and x-timestamp headers from the request.
/// Using the supported algorithm it verifies the signature of the request.
///
/// A 30 second time difference is allowed between the timestamp in the request and the server time.
///
/// Checking if the public-key is valid is not done in this middleware.
/// </summary>
public class SignatureVerificationMiddleware
{
private readonly RequestDelegate _next;
private readonly IDateTimeProvider _dateTimeProvider;
private readonly ILogger<SignatureVerificationMiddleware> _logger;

public SignatureVerificationMiddleware(
RequestDelegate next,
IDateTimeProvider dateTimeProvider,
ILogger<SignatureVerificationMiddleware> logger)
{
_next = next;
_dateTimeProvider = dateTimeProvider;
_logger = logger;
}

private enum SignatureAlgorithmHeader
{
ED25519
}

public async Task Invoke(HttpContext context)
{
try
// Retrieve headers from the request
string? signatureHeader = context.Request.Headers["x-signature"];
string? algorithmHeader = context.Request.Headers["x-algorithm"];
string? publicKeyHeader = context.Request.Headers["x-public-key"];
string? timestampHeader = context.Request.Headers["x-timestamp"];

// Make sure the headers are present
if (!Enum.TryParse<SignatureAlgorithmHeader>(algorithmHeader, ignoreCase: true, out _))
{
// Retrieve headers from the request
string? signatureHeader = context.Request.Headers["x-signature"];
string? payloadHeader = context.Request.Headers["x-payload"];
string? publicKeyHeader = context.Request.Headers["x-public-key"];

// Retrieve method from the request
var method = context.Request.Method;

if (string.IsNullOrWhiteSpace(payloadHeader))
{
_logger.LogError("Missing payload header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-payload"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(signatureHeader))
{
_logger.LogError("Missing signature header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(publicKeyHeader))
{
_logger.LogError("Missing publicKey header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-public-key"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

byte[] payloadBytes = Convert.FromBase64String(payloadHeader);
var payloadString = Encoding.UTF8.GetString(payloadBytes);

JObject payloadJson = JObject.Parse(payloadString);

byte[] publicKeyBytes = Convert.FromBase64String(publicKeyHeader);

// Get the current Unix UTC timestamp (rounded to 30 seconds)
long currentTimestamp = (long)(DateTime.UtcNow.Subtract(new DateTime(1970, 1, 1))).TotalSeconds;
currentTimestamp = (currentTimestamp / 30) * 30; // Round to the nearest 30 seconds

// Decode the signature header from Base64
byte[]? signatureBytes = Convert.FromBase64String(signatureHeader);

long timestamp = 0;

// Check if the "timestamp" property is present
if (payloadJson.TryGetValue(SignaturePayload.Timestamp, out var timestampToken))
{
// Extract the timestamp value
timestamp = (long)timestampToken;
}
else
{
_logger.LogError("Missing timestamp in header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

bool isCurrentTime = timestamp == currentTimestamp;

long allowedDifference = 30; // 30 seconds
bool isWithin30Seconds = Math.Abs(currentTimestamp - timestamp) <= allowedDifference;

if (isCurrentTime || isWithin30Seconds)
{
if (VerifySignature(publicKeyBytes, payloadBytes, signatureBytes))
{
await _next(context); // Signature is valid, continue with the request
}
else
{
_logger.LogError("Invalid signature");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid signature", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
}
}
else
{
_logger.LogError("Timestamp outdated");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid timestamp", "timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
}
_logger.LogError("Invalid algorithm header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid Header", "x-algorithm"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}
catch (CustomErrorsException ex)

if (string.IsNullOrWhiteSpace(signatureHeader))
{
_logger.LogError(ex, "Unknown exception thrown: {message}", ex.Message);
throw;
_logger.LogError("Missing signature header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}
catch (Exception ex)

if (string.IsNullOrWhiteSpace(publicKeyHeader))
{
_logger.LogError(ex, "Unknown exception thrown: {message}", ex.Message);
var customErrors = new CustomErrors(new CustomError("Forbidden", ex.Message, ex.Source!));
_logger.LogError("Missing publicKey header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-public-key"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

if (string.IsNullOrWhiteSpace(timestampHeader)
|| !long.TryParse(timestampHeader, out var timestampHeaderLong))
{
_logger.LogError("Missing timestamp header");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Missing Header", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// Check if the timestamp is within the allowed time
if (!IsWithinAllowedTime(timestampHeaderLong))
{
_logger.LogError("Timestamp outdated");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid timestamp", "x-timestamp"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
return;
}

// TODO: Check if the public key is valid according to algorithm

var payloadSigningStream = await GetPayloadStream(context, timestampHeader);

// Parse the public key
var publicKeyBytes = Convert.FromBase64String(publicKeyHeader);

// Decode the signature header from Base64
var signatureBytes = Convert.FromBase64String(signatureHeader);

if (VerifySignature(publicKeyBytes, payloadSigningStream.ToArray(), signatureBytes))
{
// Signature is valid, continue with the request
await _next(context);
}
else
{
_logger.LogError("Invalid signature");
var customErrors = new CustomErrors(new CustomError("Forbidden", "Invalid signature", "x-signature"));
await WriteCustomErrors(context.Response, customErrors, (int)HttpStatusCode.Forbidden);
}
}

private static async Task<MemoryStream> GetPayloadStream(HttpContext context, string timestampHeader)
{
// Leave the body open so the next middleware can read it.
context.Request.EnableBuffering();

// Set-up the payload stream to verify the signature
var payloadSigningStream = new MemoryStream();

// Copy the timestamp to the payload stream
await payloadSigningStream.WriteAsync(Encoding.UTF8.GetBytes(timestampHeader));

// Copy the request body to the payload stream
await context.Request.Body.CopyToAsync(payloadSigningStream);

// Reset the request body stream position so the next middleware can read it
context.Request.Body.Position = 0;

return payloadSigningStream;
}

private bool IsWithinAllowedTime(long timestampHeaderLong)
{
var suppliedDateTime = DateTimeOffset.FromUnixTimeSeconds(timestampHeaderLong);
var dateDiff = _dateTimeProvider.UtcNow - suppliedDateTime;
long allowedDifference = 30; // 30 seconds
return Math.Abs(dateDiff.TotalSeconds) <= allowedDifference;
}

private static async Task WriteCustomErrors(HttpResponse httpResponse, CustomErrors customErrors, int statusCode)
private static async Task WriteCustomErrors(HttpResponse httpResponse, CustomErrors customErrors,
int statusCode)
{
httpResponse.StatusCode = statusCode;
httpResponse.ContentType = "application/json";
Expand Down Expand Up @@ -161,4 +176,4 @@ public static void ConfigureSignatureVerificationMiddleware(this IApplicationBui
app.UseMiddleware<SignatureVerificationMiddleware>();
}
}
}
}
Loading

0 comments on commit b26e2e0

Please sign in to comment.