diff --git a/src/Sdk.UnitTests/Cerbos/Sdk/UnitTests/CerbosClientTest.cs b/src/Sdk.UnitTests/Cerbos/Sdk/UnitTests/CerbosClientTest.cs index 59ad9d0..31b9076 100644 --- a/src/Sdk.UnitTests/Cerbos/Sdk/UnitTests/CerbosClientTest.cs +++ b/src/Sdk.UnitTests/Cerbos/Sdk/UnitTests/CerbosClientTest.cs @@ -21,6 +21,8 @@ public class CerbosClientTest private const string Tag = "dev"; private const string PathToPolicies = "./../../../res/policies"; private const string PathToConfig = "./../../../res/config"; + private readonly Grpc.Core.Metadata _metadata = new() { { "wibble", "wobble" } }; + private IContainer? _container; private CerbosClient? _client; @@ -45,8 +47,8 @@ public void OneTimeSetUp() Task.Run(async () => await _container.StartAsync()).Wait(); Thread.Sleep(3000); - _client = CerbosClientBuilder.ForTarget("http://127.0.0.1:3593").WithPlaintext().Build(); - _clientPlayground = CerbosClientBuilder.ForTarget(PlaygroundHost).WithPlaygroundInstance(PlaygroundInstanceId).Build(); + _client = CerbosClientBuilder.ForTarget("http://127.0.0.1:3593").WithMetadata(_metadata).WithPlaintext().Build(); + _clientPlayground = CerbosClientBuilder.ForTarget(PlaygroundHost).WithMetadata(_metadata).WithPlaygroundInstance(PlaygroundInstanceId).Build(); } [OneTimeTearDown] @@ -81,7 +83,7 @@ public void CheckWithoutJwt() ) .WithIncludeMeta(true); - var have = _client.CheckResources(request).Find("XX125"); + var have = _client.CheckResources(request, _metadata).Find("XX125"); Assert.That(have.IsAllowed("view:public"), Is.True); Assert.That(have.IsAllowed("approve"), Is.False); @@ -126,7 +128,7 @@ public void CheckWithJwt() .WithActions("defer") ); - var have = _client.CheckResources(request).Find("XX125"); + var have = _client.CheckResources(request, _metadata).Find("XX125"); Assert.That(have.IsAllowed("defer"), Is.True); } @@ -175,7 +177,7 @@ public void CheckMultiple() ); - var have = _client.CheckResources(request); + var have = _client.CheckResources(request, _metadata); var resourcexx125 = have.Find("XX125"); Assert.That(resourcexx125.IsAllowed("view:public"), Is.True); Assert.That(resourcexx125.IsAllowed("defer"), Is.True); @@ -213,7 +215,7 @@ public void PlanResources() .WithAction("approve"); - var have = _client.PlanResources(request); + var have = _client.PlanResources(request, _metadata); Assert.That(have.Action, Is.EqualTo("approve")); Assert.That(have.PolicyVersion, Is.EqualTo("20210210")); Assert.That(have.ResourceKind, Is.EqualTo("leave_request")); @@ -260,7 +262,7 @@ public void PlanResourcesValidation() ) .WithAction("approve"); - var have = _client.PlanResources(request); + var have = _client.PlanResources(request, _metadata); Assert.That(have.Action, Is.EqualTo("approve")); Assert.That(have.PolicyVersion, Is.EqualTo("20210210")); Assert.That(have.ResourceKind, Is.EqualTo("leave_request")); @@ -302,7 +304,7 @@ public void Playground() .WithActions("approve", "delete") ); - var have = _clientPlayground.CheckResources(request).Find("XX125"); + var have = _clientPlayground.CheckResources(request, _metadata).Find("XX125"); Assert.That(have.IsAllowed("approve"), Is.True); Assert.That(have.IsAllowed("delete"), Is.True); } @@ -333,7 +335,7 @@ public async Task CheckWithoutJwtAsync() .WithActions("approve", "view:public") ); - var have = (await _client.CheckResourcesAsync(request)).Find("XX125"); + var have = (await _client.CheckResourcesAsync(request, _metadata)).Find("XX125"); Assert.That(have.IsAllowed("view:public"), Is.True); Assert.That(have.IsAllowed("approve"), Is.False); diff --git a/src/Sdk/Cerbos/Sdk/Builder/CerbosClientBuilder.cs b/src/Sdk/Cerbos/Sdk/Builder/CerbosClientBuilder.cs index bd19452..098e835 100644 --- a/src/Sdk/Cerbos/Sdk/Builder/CerbosClientBuilder.cs +++ b/src/Sdk/Cerbos/Sdk/Builder/CerbosClientBuilder.cs @@ -3,8 +3,8 @@ using System; using System.IO; -using System.Threading.Tasks; using Grpc.Core; +using Grpc.Core.Interceptors; using Grpc.Net.Client; namespace Cerbos.Sdk.Builder @@ -20,6 +20,7 @@ public sealed class CerbosClientBuilder private StreamReader TlsCertificate { get; set; } private StreamReader TlsKey { get; set; } private GrpcChannelOptions GrpcChannelOptions { get; set; } + private Metadata Metadata { get; set; } private CerbosClientBuilder(string target) { Target = target; @@ -30,6 +31,12 @@ public static CerbosClientBuilder ForTarget(string target) return new CerbosClientBuilder(target); } + public CerbosClientBuilder WithMetadata(Metadata headers) + { + Metadata = headers; + return this; + } + public CerbosClientBuilder WithPlaintext() { Plaintext = true; return this; @@ -81,14 +88,15 @@ public CerbosClient Build() ); } - CallCredentials callCredentials = null; + Metadata combined = Metadata; if (!string.IsNullOrEmpty(PlaygroundInstanceId)) - { - callCredentials = CallCredentials.FromInterceptor((context, metadata) => - { - metadata.Add(PlaygroundInstanceHeader, PlaygroundInstanceId.Trim()); - return Task.CompletedTask; - }); + { + combined = Utility.Metadata.Merge( + Metadata, + new Metadata { + { PlaygroundInstanceHeader, PlaygroundInstanceId.Trim() } + } + ); } SslCredentials sslCredentials = null; @@ -103,39 +111,21 @@ public CerbosClient Build() sslCredentials = new SslCredentials(CaCertificate.ReadToEnd()); } } - - GrpcChannel channel; - if (Plaintext) + + var grpcChannelOptions = GrpcChannelOptions ?? new GrpcChannelOptions(); + if (sslCredentials != null) { - if (GrpcChannelOptions != null) - { - channel = GrpcChannel.ForAddress(Target, GrpcChannelOptions); - } - else - { - channel = GrpcChannel.ForAddress(Target); - } + grpcChannelOptions.Credentials = sslCredentials; } - else + else if (!Plaintext) { - GrpcChannelOptions grpcChannelOptions = GrpcChannelOptions ?? new GrpcChannelOptions(); - if (callCredentials != null && sslCredentials != null) - { - grpcChannelOptions.Credentials = ChannelCredentials.Create(sslCredentials, callCredentials); - } - else if (sslCredentials != null) - { - grpcChannelOptions.Credentials = sslCredentials; - } - else if (callCredentials != null) - { - grpcChannelOptions.Credentials = ChannelCredentials.Create(ChannelCredentials.SecureSsl, callCredentials); - } - - channel = GrpcChannel.ForAddress(Target, grpcChannelOptions); + grpcChannelOptions.Credentials = ChannelCredentials.SecureSsl; } - return new CerbosClient(new Api.V1.Svc.CerbosService.CerbosServiceClient(channel)); + var grpcChannel = GrpcChannel + .ForAddress(Target, grpcChannelOptions) + .Intercept(); + return new CerbosClient(new Api.V1.Svc.CerbosService.CerbosServiceClient(grpcChannel), combined); } } } \ No newline at end of file diff --git a/src/Sdk/Cerbos/Sdk/CerbosClient.cs b/src/Sdk/Cerbos/Sdk/CerbosClient.cs index f004093..668723b 100644 --- a/src/Sdk/Cerbos/Sdk/CerbosClient.cs +++ b/src/Sdk/Cerbos/Sdk/CerbosClient.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; using Cerbos.Sdk.Response; +using Grpc.Core; namespace Cerbos.Sdk { @@ -13,20 +14,22 @@ namespace Cerbos.Sdk public sealed class CerbosClient { private Api.V1.Svc.CerbosService.CerbosServiceClient CerbosServiceClient { get; } - - public CerbosClient(Api.V1.Svc.CerbosService.CerbosServiceClient cerbosServiceClient) + private readonly Metadata _metadata; + + public CerbosClient(Api.V1.Svc.CerbosService.CerbosServiceClient cerbosServiceClient, Metadata metadata = null) { CerbosServiceClient = cerbosServiceClient; + _metadata = metadata; } /// /// Send a request consisting of a principal, resource(s) & action(s) to see if the principal is authorized to do the action(s) on the resource(s). /// - public CheckResourcesResponse CheckResources(Builder.CheckResourcesRequest request) + public CheckResourcesResponse CheckResources(Builder.CheckResourcesRequest request, Metadata headers = null) { try { - return new CheckResourcesResponse(CerbosServiceClient.CheckResources(request.ToCheckResourcesRequest())); + return new CheckResourcesResponse(CerbosServiceClient.CheckResources(request.ToCheckResourcesRequest(), Utility.Metadata.Merge(_metadata, headers))); } catch (Exception e) { @@ -37,12 +40,12 @@ public CheckResourcesResponse CheckResources(Builder.CheckResourcesRequest reque /// /// Send an async request consisting of a principal, resource(s) & action(s) to see if the principal is authorized to do the action(s) on the resource(s). /// - public Task CheckResourcesAsync(Builder.CheckResourcesRequest request) + public Task CheckResourcesAsync(Builder.CheckResourcesRequest request, Metadata headers = null) { try { return CerbosServiceClient - .CheckResourcesAsync(request.ToCheckResourcesRequest()) + .CheckResourcesAsync(request.ToCheckResourcesRequest(), Utility.Metadata.Merge(_metadata, headers)) .ResponseAsync .ContinueWith( r => new CheckResourcesResponse(r.Result) @@ -57,11 +60,11 @@ public Task CheckResourcesAsync(Builder.CheckResourcesRe /// /// Obtain a query plan for performing the given action on the given resource kind. /// - public PlanResourcesResponse PlanResources(Builder.PlanResourcesRequest request) + public PlanResourcesResponse PlanResources(Builder.PlanResourcesRequest request, Metadata headers = null) { try { - return new PlanResourcesResponse(CerbosServiceClient.PlanResources(request.ToPlanResourcesRequest())); + return new PlanResourcesResponse(CerbosServiceClient.PlanResources(request.ToPlanResourcesRequest(), Utility.Metadata.Merge(_metadata, headers))); } catch (Exception e) { @@ -72,12 +75,12 @@ public PlanResourcesResponse PlanResources(Builder.PlanResourcesRequest request) /// /// Obtain a query plan for performing the given action on the given resource kind. /// - public Task PlanResourcesAsync(Builder.PlanResourcesRequest request) + public Task PlanResourcesAsync(Builder.PlanResourcesRequest request, Metadata headers = null) { try { return CerbosServiceClient - .PlanResourcesAsync(request.ToPlanResourcesRequest()) + .PlanResourcesAsync(request.ToPlanResourcesRequest(), Utility.Metadata.Merge(_metadata, headers)) .ResponseAsync .ContinueWith( r => new PlanResourcesResponse(r.Result) diff --git a/src/Sdk/Cerbos/Sdk/Utility/Metadata.cs b/src/Sdk/Cerbos/Sdk/Utility/Metadata.cs new file mode 100644 index 0000000..561250b --- /dev/null +++ b/src/Sdk/Cerbos/Sdk/Utility/Metadata.cs @@ -0,0 +1,36 @@ +namespace Cerbos.Sdk.Utility +{ + public static class Metadata + { + public static Grpc.Core.Metadata Merge(Grpc.Core.Metadata first, Grpc.Core.Metadata second) + { + if (first == null && second == null) + { + return null; + } + + if (first != null && second == null) + { + return first; + } + + if (first == null) + { + return second; + } + + Grpc.Core.Metadata combined = new Grpc.Core.Metadata(); + foreach (var m in first) + { + combined.Add(m); + } + + foreach (var m in second) + { + combined.Add(m); + } + + return combined; + } + } +} \ No newline at end of file