From 63216c9d50c45efbe1093026bc86d036e8491be1 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 8 Aug 2024 19:39:53 -0400 Subject: [PATCH 01/19] added azure_credentials_provider.h to the public headers update rl_sim to use azure_credentials_provider with DefaultAzureCredential added support for the azure_credentials_provider in rl.net.native and rl.net added cmake checks for dotnet-t4 --- .../cs/common/codegen/TextTemplate.targets | 16 +- bindings/cs/rl.net.native/CMakeLists.txt | 9 ++ .../rl.net.native/rl.net.azure_factories.cc | 47 ++++++ .../cs/rl.net.native/rl.net.azure_factories.h | 33 ++++ bindings/cs/rl.net/ApiStatus.cs | 12 +- bindings/cs/rl.net/CALoop.cs | 24 +-- bindings/cs/rl.net/CBLoop.cs | 24 +-- bindings/cs/rl.net/CCBLoop.cs | 40 ++--- bindings/cs/rl.net/CMakeLists.txt | 10 ++ bindings/cs/rl.net/Configuration.cs | 10 +- .../cs/rl.net/ContinuousActionResponse.cs | 12 +- bindings/cs/rl.net/DecisionResponse.cs | 26 ++-- bindings/cs/rl.net/FactoryContext.cs | 8 +- bindings/cs/rl.net/LiveModel.cs | 54 +++---- bindings/cs/rl.net/MultiSlotResponse.cs | 24 +-- .../cs/rl.net/MultiSlotResponseDetailed.cs | 20 +-- bindings/cs/rl.net/Native/Global.cs | 2 +- bindings/cs/rl.net/Native/NativeImports.cs | 9 ++ bindings/cs/rl.net/NativeCallbacks.cs | 3 + bindings/cs/rl.net/OAuthCredentialProvider.cs | 143 ++++++++++++++++++ bindings/cs/rl.net/RankingResponse.cs | 22 +-- bindings/cs/rl.net/SharedBuffer.cs | 8 +- bindings/cs/rl.net/SlatesLoop.cs | 32 ++-- bindings/cs/rl.net/SlotRanking.cs | 20 +-- examples/rl_sim_cpp/CMakeLists.txt | 5 - examples/rl_sim_cpp/azure_credentials.cc | 67 -------- examples/rl_sim_cpp/azure_credentials.h | 30 ---- examples/rl_sim_cpp/main.cc | 3 +- examples/rl_sim_cpp/rl_sim.cc | 6 +- examples/rl_sim_cpp/rl_sim.h | 11 +- include/azure_credentials_provider.h | 80 ++++++++++ rlclientlib/CMakeLists.txt | 1 + 32 files changed, 529 insertions(+), 282 deletions(-) create mode 100644 bindings/cs/rl.net.native/rl.net.azure_factories.cc create mode 100644 bindings/cs/rl.net.native/rl.net.azure_factories.h create mode 100644 bindings/cs/rl.net/Native/NativeImports.cs create mode 100644 bindings/cs/rl.net/OAuthCredentialProvider.cs delete mode 100644 examples/rl_sim_cpp/azure_credentials.cc delete mode 100644 examples/rl_sim_cpp/azure_credentials.h create mode 100644 include/azure_credentials_provider.h diff --git a/bindings/cs/common/codegen/TextTemplate.targets b/bindings/cs/common/codegen/TextTemplate.targets index 453257a65..e335aab2c 100644 --- a/bindings/cs/common/codegen/TextTemplate.targets +++ b/bindings/cs/common/codegen/TextTemplate.targets @@ -5,8 +5,9 @@ - dotnet t4 - t4 + + + t4 @@ -17,7 +18,8 @@ @(TextTransformParameter -> '-a !!%(Identity)!%(Value)', ' ') - $(T4Command) "%(TextTemplate.Identity)" -o "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters) + $(T4Command) "%(TextTemplate.Identity)" -o "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters) + $(T4Command) "%(TextTemplate.Identity)" -out "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters) @@ -25,6 +27,14 @@ + + + + + + + + diff --git a/bindings/cs/rl.net.native/CMakeLists.txt b/bindings/cs/rl.net.native/CMakeLists.txt index c7dba0dd5..f81391226 100644 --- a/bindings/cs/rl.net.native/CMakeLists.txt +++ b/bindings/cs/rl.net.native/CMakeLists.txt @@ -45,6 +45,15 @@ set(rl_net_native_HEADERS rl.net.slot_ranking.h ) +if(RL_LINK_AZURE_LIBS) + list(APPEND rl_net_native_HEADERS + rl.net.azure_factories.h + ) + list(APPEND rl_net_native_SOURCES + rl.net.azure_factories.cc + ) +endif() + source_group("Sources" FILES ${rl_net_native_SOURCES}) source_group("Headers" FILES ${rl_net_native_HEADERS}) diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.cc b/bindings/cs/rl.net.native/rl.net.azure_factories.cc new file mode 100644 index 000000000..d850e4922 --- /dev/null +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.cc @@ -0,0 +1,47 @@ +#include "rl.net.azure_factories.h" + +namespace rl_net_native +{ +rl_net_native::azure_factory_oauth_callback_t g_oauth_callback = nullptr; +rl_net_native::azure_factory_oauth_callback_complete_t g_oauth_callback_complete = nullptr; + +static int azure_factory_oauth_callback(const std::vector& scopes, + std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry) +{ + if (g_oauth_callback == nullptr) { + return -1; + } + // create a null terminated array of scope string pointers + // these are pointers are readonly and owned by the caller + std::vector native_scopes; + native_scopes.reserve(scopes.size() + 1); + for (int i = 0; i < scopes.size(); ++i) + { + native_scopes.push_back(scopes[i].c_str()); + } + native_scopes.push_back(nullptr); + // we expect to get a pointer to a null terminated string + // it's allocated on the managed heap, so we can't free it here + // instead, we will pass it back to the managed code after we copy it + char *oauth_token_ptr = nullptr; + std::int64_t expiryUnixTime = 0; + auto ret = g_oauth_callback(native_scopes.data(), &oauth_token_ptr, &expiryUnixTime); + if (ret == reinforcement_learning::error_code::success) { + oauth_token = oauth_token_ptr; + token_expiry = std::chrono::system_clock::from_time_t(expiryUnixTime); + } + g_oauth_callback_complete(oauth_token_ptr, reinforcement_learning::error_code::success); + return ret; +} +} + +API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback, + rl_net_native::azure_factory_oauth_callback_complete_t completion) +{ + if (rl_net_native::g_oauth_callback == nullptr) { + rl_net_native::g_oauth_callback = callback; + reinforcement_learning::register_default_factories_callback( + reinforcement_learning::oauth_callback_t{rl_net_native::azure_factory_oauth_callback}); + rl_net_native::g_oauth_callback_complete = completion; + } +} diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.h b/bindings/cs/rl.net.native/rl.net.azure_factories.h new file mode 100644 index 000000000..1ae564c14 --- /dev/null +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.h @@ -0,0 +1,33 @@ +#pragma once + +#include "rl.net.native.h" +#include "factory_resolver.h" + +#include + +namespace rl_net_native +{ +// Callback function to be called by the C# code to get the OAuth token +// The callback function should return 0 on success and non-zero on failure (see reinforcement_learning::error_code) +// prototype - azure_factory_oauth_callback(const char** scopes, char** token, std::int64_t* expiryUnixTime) +// scopes - null terminated array of scope strings +// token - out pointer to a null terminated string allocated on the managed heap +// expiryUnixTime - pointer to the expiry time of the token in Unix time +typedef int (*azure_factory_oauth_callback_t)(const char**, char**, std::int64_t*); + +// Callback function to be called by the C# code to signal the completion of the OAuth token request +// this provides the C# code with the opportunity to free the memory allocated for the token +typedef void (*azure_factory_oauth_callback_complete_t)(char*, int); +} // namespace rl_net_native + +extern "C" +{ + // Register the callback function to be called by the C++ code to get the OAuth token + // both the callback and completion functions must be registered before any calls to the Azure factories are made + // typically, this function should be called once during the initialization of the application + // callback - the callback function to be called by the C++ code to get the OAuth token + // completion - the callback function to be called by the C++ code to signal the completion of the OAuth token request + API void RegisterDefaultFactoriesCallback( + rl_net_native::azure_factory_oauth_callback_t callback, + rl_net_native::azure_factory_oauth_callback_complete_t completion); +} diff --git a/bindings/cs/rl.net/ApiStatus.cs b/bindings/cs/rl.net/ApiStatus.cs index 68d6af004..63c5b6182 100644 --- a/bindings/cs/rl.net/ApiStatus.cs +++ b/bindings/cs/rl.net/ApiStatus.cs @@ -8,22 +8,22 @@ namespace Rl.Net { public sealed class ApiStatus : NativeObject { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateApiStatus(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteApiStatus(IntPtr config); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr GetApiStatusErrorMessage(IntPtr status); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int GetApiStatusErrorCode(IntPtr status); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void UpdateApiStatusSafe(IntPtr status, int error_code, IntPtr message); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void ClearApiStatusSafe(IntPtr status); public ApiStatus() : base(new New(CreateApiStatus), new Delete(DeleteApiStatus)) diff --git a/bindings/cs/rl.net/CALoop.cs b/bindings/cs/rl.net/CALoop.cs index a4c88178b..5b586f7f7 100644 --- a/bindings/cs/rl.net/CALoop.cs +++ b/bindings/cs/rl.net/CALoop.cs @@ -10,16 +10,16 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateCALoop(IntPtr config, IntPtr factoryContext); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteCALoop(IntPtr caLoop); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CALoopInit(IntPtr caLoop, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "CALoopRequestContinuousAction")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopRequestContinuousAction")] private static extern int CALoopRequestContinuousActionNative(IntPtr caLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus); internal static Func CALoopRequestContinuousActionOverride { get; set; } @@ -34,7 +34,7 @@ public static int CALoopRequestContinuousAction(IntPtr caLoop, IntPtr eventId, I return CALoopRequestContinuousActionNative(caLoop, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CALoopRequestContinuousActionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopRequestContinuousActionWithFlags")] private static extern int CALoopRequestContinuousActionWithFlagsNative(IntPtr caLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr apiStatus); internal static Func CALoopRequestContinuousActionWithFlagsOverride { get; set; } @@ -49,7 +49,7 @@ public static int CALoopRequestContinuousActionWithFlags(IntPtr caLoop, IntPtr e return CALoopRequestContinuousActionWithFlagsNative(caLoop, eventId, contextJson, contextJsonSize, flags, continuousActionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CALoopReportActionTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportActionTaken")] private static extern int CALoopReportActionTakenNative(IntPtr caLoop, IntPtr eventId, IntPtr apiStatus); internal static Func CALoopReportActionTakenOverride { get; set; } @@ -64,7 +64,7 @@ public static int CALoopReportActionTaken(IntPtr caLoop, IntPtr eventId, IntPtr return CALoopReportActionTakenNative(caLoop, eventId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CALoopReportActionMultiIdTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportActionMultiIdTaken")] private static extern int CALoopReportActionTakenMultiIdNative(IntPtr caLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); internal static Func CALoopReportActionTakenMultiIdOverride { get; set; } @@ -79,7 +79,7 @@ public static int CALoopReportActionMultiIdTaken(IntPtr caLoop, IntPtr primaryId return CALoopReportActionTakenMultiIdNative(caLoop, primaryId, secondaryId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CALoopReportOutcomeF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportOutcomeF")] private static extern int CALoopReportOutcomeFNative(IntPtr caLoop, IntPtr eventId, float outcome, IntPtr apiStatus); internal static Func CALoopReportOutcomeFOverride { get; set; } @@ -94,7 +94,7 @@ public static int CALoopReportOutcomeF(IntPtr caLoop, IntPtr eventId, float outc return CALoopReportOutcomeFNative(caLoop, eventId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CALoopReportOutcomeJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportOutcomeJson")] private static extern int CALoopReportOutcomeJsonNative(IntPtr caLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func CALoopReportOutcomeJsonOverride { get; set; } @@ -109,13 +109,13 @@ public static int CALoopReportOutcomeJson(IntPtr caLoop, IntPtr eventId, IntPtr return CALoopReportOutcomeJsonNative(caLoop, eventId, outcomeJson, apiStatus); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CALoopRefreshModel(IntPtr caLoop, IntPtr apiStatus); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CALoopSetCallback(IntPtr caLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CALoopSetTrace(IntPtr caLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); } } diff --git a/bindings/cs/rl.net/CBLoop.cs b/bindings/cs/rl.net/CBLoop.cs index 5b8846697..dd2ed086d 100644 --- a/bindings/cs/rl.net/CBLoop.cs +++ b/bindings/cs/rl.net/CBLoop.cs @@ -10,16 +10,16 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateCBLoop(IntPtr config, IntPtr factoryContext); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteCBLoop(IntPtr cbLoop); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CBLoopInit(IntPtr cbLoop, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRank")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopChooseRank")] private static extern int CBLoopChooseRankNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); internal static Func CBLoopChooseRankOverride { get; set; } @@ -34,7 +34,7 @@ public static int CBLoopChooseRank(IntPtr cbLoop, IntPtr eventId, IntPtr context return CBLoopChooseRankNative(cbLoop, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CBLoopChooseRankWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopChooseRankWithFlags")] private static extern int CBLoopChooseRankWithFlagsNative(IntPtr cbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); internal static Func CBLoopChooseRankWithFlagsOverride { get; set; } @@ -49,7 +49,7 @@ public static int CBLoopChooseRankWithFlags(IntPtr cbLoop, IntPtr eventId, IntPt return CBLoopChooseRankWithFlagsNative(cbLoop, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CBLoopReportActionTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopReportActionTaken")] private static extern int CBLoopReportActionTakenNative(IntPtr cbLoop, IntPtr eventId, IntPtr apiStatus); internal static Func CBLoopReportActionTakenOverride { get; set; } @@ -64,7 +64,7 @@ public static int CBLoopReportActionTaken(IntPtr cbLoop, IntPtr eventId, IntPtr return CBLoopReportActionTakenNative(cbLoop, eventId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CBLoopReportActionMultiIdTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopReportActionMultiIdTaken")] private static extern int CBLoopReportActionTakenMultiIdNative(IntPtr cbLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); internal static Func CBLoopReportActionTakenMultiIdOverride { get; set; } @@ -79,7 +79,7 @@ public static int CBLoopReportActionMultiIdTaken(IntPtr cbLoop, IntPtr primaryId return CBLoopReportActionTakenMultiIdNative(cbLoop, primaryId, secondaryId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CBLoopReportOutcomeF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopReportOutcomeF")] private static extern int CBLoopReportOutcomeFNative(IntPtr cbLoop, IntPtr eventId, float outcome, IntPtr apiStatus); internal static Func CBLoopReportOutcomeFOverride { get; set; } @@ -94,7 +94,7 @@ public static int CBLoopReportOutcomeF(IntPtr cbLoop, IntPtr eventId, float outc return CBLoopReportOutcomeFNative(cbLoop, eventId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CBLoopReportOutcomeJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CBLoopReportOutcomeJson")] private static extern int CBLoopReportOutcomeJsonNative(IntPtr cbLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func CBLoopReportOutcomeJsonOverride { get; set; } @@ -109,13 +109,13 @@ public static int CBLoopReportOutcomeJson(IntPtr cbLoop, IntPtr eventId, IntPtr return CBLoopReportOutcomeJsonNative(cbLoop, eventId, outcomeJson, apiStatus); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CBLoopRefreshModel(IntPtr cbLoop, IntPtr apiStatus); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CBLoopSetCallback(IntPtr cbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CBLoopSetTrace(IntPtr cbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); } } diff --git a/bindings/cs/rl.net/CCBLoop.cs b/bindings/cs/rl.net/CCBLoop.cs index 75dc28632..265e93e05 100644 --- a/bindings/cs/rl.net/CCBLoop.cs +++ b/bindings/cs/rl.net/CCBLoop.cs @@ -10,16 +10,16 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateCCBLoop(IntPtr config, IntPtr factoryContext); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteCCBLoop(IntPtr ccbLoop); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CCBLoopInit(IntPtr ccbLoop, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestDecision")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestDecision")] private static extern int CCBLoopRequestDecisionNative(IntPtr ccbLoop, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus); internal static Func CCBLoopRequestDecisionOverride { get; set; } @@ -34,7 +34,7 @@ public static int CCBLoopRequestDecision(IntPtr ccbLoop, IntPtr contextJson, int return CCBLoopRequestDecisionNative(ccbLoop, contextJson, contextJsonSize, decisionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestDecisionWithFlags")] private static extern int CCBLoopRequestDecisionWithFlagsNative(IntPtr ccbLoop, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr decisionResponse, IntPtr apiStatus); internal static Func CCBLoopRequestDecisionWithFlagsOverride { get; set; } @@ -49,7 +49,7 @@ public static int CCBLoopRequestDecisionWithFlags(IntPtr ccbLoop, IntPtr context return CCBLoopRequestDecisionWithFlagsNative(ccbLoop, contextJson, contextJsonSize, flags, decisionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecision")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecision")] private static extern int CCBLoopRequestMultiSlotDecisionNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionOverride { get; set; } @@ -64,7 +64,7 @@ public static int CCBLoopRequestMultiSlotDecision(IntPtr ccbLoop, IntPtr eventId return CCBLoopRequestMultiSlotDecisionNative(ccbLoop, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecisionWithFlags")] private static extern int CCBLoopRequestMultiSlotDecisionWithFlagsNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionWithFlagsOverride { get; set; } @@ -79,7 +79,7 @@ public static int CCBLoopRequestMultiSlotDecisionWithFlags(IntPtr ccbLoop, IntPt return CCBLoopRequestMultiSlotDecisionWithFlagsNative(ccbLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecisionWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecisionWithBaselineAndFlags")] private static extern int CCBLoopRequestMultiSlotDecisionWithBaselineAndFlagsNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionWithBaselineAndFlagsOverride { get; set; } @@ -94,7 +94,7 @@ public static int CCBLoopRequestMultiSlotDecisionWithBaselineAndFlags(IntPtr ccb return CCBLoopRequestMultiSlotDecisionWithBaselineAndFlagsNative(ccbLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailed")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailed")] private static extern int CCBLoopRequestMultiSlotDecisionDetailedNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionDetailedOverride { get; set; } @@ -109,7 +109,7 @@ public static int CCBLoopRequestMultiSlotDecisionDetailed(IntPtr ccbLoop, IntPtr return CCBLoopRequestMultiSlotDecisionDetailedNative(ccbLoop, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailedWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailedWithFlags")] private static extern int CCBLoopRequestMultiSlotDecisionDetailedWithFlagsNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionDetailedWithFlagsOverride { get; set; } @@ -124,7 +124,7 @@ public static int CCBLoopRequestMultiSlotDecisionDetailedWithFlags(IntPtr ccbLoo return CCBLoopRequestMultiSlotDecisionDetailedWithFlagsNative(ccbLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] private static extern int CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(IntPtr ccbLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsOverride { get; set; } @@ -139,7 +139,7 @@ public static int CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags(In return CCBLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(ccbLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportActionTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportActionTaken")] private static extern int CCBLoopReportActionTakenNative(IntPtr ccbLoop, IntPtr eventId, IntPtr apiStatus); internal static Func CCBLoopReportActionTakenOverride { get; set; } @@ -154,7 +154,7 @@ public static int CCBLoopReportActionTaken(IntPtr ccbLoop, IntPtr eventId, IntPt return CCBLoopReportActionTakenNative(ccbLoop, eventId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportActionMultiIdTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportActionMultiIdTaken")] private static extern int CCBLoopReportActionTakenMultiIdNative(IntPtr ccbLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); internal static Func CCBLoopReportActionTakenMultiIdOverride { get; set; } @@ -169,7 +169,7 @@ public static int CCBLoopReportActionMultiIdTaken(IntPtr ccbLoop, IntPtr primary return CCBLoopReportActionTakenMultiIdNative(ccbLoop, primaryId, secondaryId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportOutcomeSlotF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportOutcomeSlotF")] private static extern int CCBLoopReportOutcomeSlotFNative(IntPtr ccbLoop, IntPtr eventId, uint slotIndex, float outcome, IntPtr apiStatus); internal static Func CCBLoopReportOutcomeSlotFOverride { get; set; } @@ -184,7 +184,7 @@ public static int CCBLoopReportOutcomeSlotF(IntPtr ccbLoop, IntPtr eventId, uint return CCBLoopReportOutcomeSlotFNative(ccbLoop, eventId, slotIndex, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportOutcomeSlotJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportOutcomeSlotJson")] private static extern int CCBLoopReportOutcomeSlotJsonNative(IntPtr ccbLoop, IntPtr eventId, uint slotIndex, IntPtr outcomeJson, IntPtr apiStatus); internal static Func CCBLoopReportOutcomeSlotJsonOverride { get; set; } @@ -199,7 +199,7 @@ public static int CCBLoopReportOutcomeSlotJson(IntPtr ccbLoop, IntPtr eventId, u return CCBLoopReportOutcomeSlotJsonNative(ccbLoop, eventId, slotIndex, outcomeJson, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportOutcomeSlotStringIdF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportOutcomeSlotStringIdF")] private static extern int CCBLoopReportOutcomeSlotStringIdFNative(IntPtr ccbLoop, IntPtr eventId, IntPtr slotId, float outcome, IntPtr apiStatus); internal static Func CCBLoopReportOutcomeSlotStringIdFOverride { get; set; } @@ -214,7 +214,7 @@ public static int CCBLoopReportOutcomeSlotStringIdF(IntPtr ccbLoop, IntPtr event return CCBLoopReportOutcomeSlotStringIdFNative(ccbLoop, eventId, slotId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "CCBLoopReportOutcomeSlotStringIdJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CCBLoopReportOutcomeSlotStringIdJson")] private static extern int CCBLoopReportOutcomeSlotStringIdJsonNative(IntPtr ccbLoop, IntPtr eventId, IntPtr slotId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func CCBLoopReportOutcomeSlotStringIdJsonOverride { get; set; } @@ -229,13 +229,13 @@ public static int CCBLoopReportOutcomeSlotStringIdJson(IntPtr ccbLoop, IntPtr ev return CCBLoopReportOutcomeSlotStringIdJsonNative(ccbLoop, eventId, slotId, outcomeJson, apiStatus); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int CCBLoopRefreshModel(IntPtr ccbLoop, IntPtr apiStatus); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CCBLoopSetCallback(IntPtr ccbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void CCBLoopSetTrace(IntPtr ccbLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); } } diff --git a/bindings/cs/rl.net/CMakeLists.txt b/bindings/cs/rl.net/CMakeLists.txt index 29885b155..e7258c6ee 100644 --- a/bindings/cs/rl.net/CMakeLists.txt +++ b/bindings/cs/rl.net/CMakeLists.txt @@ -1,7 +1,16 @@ +find_program(DOTNET_T4_EXECUTABLE NAMES t4) + +if (DOTNET_T4_EXECUTABLE) + message(STATUS "found dotnet-t4: ${DOTNET_T4_EXECUTABLE}") +else() + message(FATAL_ERROR "dotnet-t4 tool not found. install dotnet-t4 using: dotnet tool install -g dotnet-t4") +endif() + set(RL_NET_SOURCES Native/ErrorCallback.cs Native/GCHandleLifetime.cs Native/Global.cs + Native/NativeImports.cs Native/NativeObject.cs Native/SenderAdapter.cs Native/StringExtensions.cs @@ -9,6 +18,7 @@ set(RL_NET_SOURCES ActionFlags.cs ApiStatus.cs AsyncSender.cs + OAuthCredentialProvider.cs CALoop.cs CBLoop.cs CCBLoop.cs diff --git a/bindings/cs/rl.net/Configuration.cs b/bindings/cs/rl.net/Configuration.cs index 140c3c19e..c925a83a3 100644 --- a/bindings/cs/rl.net/Configuration.cs +++ b/bindings/cs/rl.net/Configuration.cs @@ -14,13 +14,13 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateConfig(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteConfig(IntPtr config); - [DllImport("rlnetnative", EntryPoint = "LoadConfigurationFromJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LoadConfigurationFromJson")] private static extern int LoadConfigurationFromJsonNative(int jsonLength, IntPtr json, IntPtr config, IntPtr apiStatus); internal static Func LoadConfigurationFromJsonOverride { get; set; } @@ -35,7 +35,7 @@ public static int LoadConfigurationFromJson(int jsonLength, IntPtr json, IntPtr return LoadConfigurationFromJsonNative(jsonLength, json, config, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "ConfigurationSet")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "ConfigurationSet")] private static extern void ConfigurationSetNative(IntPtr config, IntPtr name, IntPtr value); internal static Action ConfigurationSetOverride { get; set; } @@ -51,7 +51,7 @@ public static void ConfigurationSet(IntPtr config, IntPtr name, IntPtr value) ConfigurationSetNative(config, name, value); } - [DllImport("rlnetnative", EntryPoint = "ConfigurationGet")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "ConfigurationGet")] private static extern IntPtr ConfigurationGetNative(IntPtr config, IntPtr name, IntPtr defVal); internal static Func ConfigurationGetOverride { get; set; } diff --git a/bindings/cs/rl.net/ContinuousActionResponse.cs b/bindings/cs/rl.net/ContinuousActionResponse.cs index 89ce76db5..6deaf9815 100644 --- a/bindings/cs/rl.net/ContinuousActionResponse.cs +++ b/bindings/cs/rl.net/ContinuousActionResponse.cs @@ -12,13 +12,13 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateContinuousActionResponse(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteContinuousActionResponse(IntPtr response); - [DllImport("rlnetnative", EntryPoint = "GetContinuousActionEventId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetContinuousActionEventId")] private static extern IntPtr GetContinuousActionEventIdNative(IntPtr response); internal static Func GetContinuousActionEventIdOverride { get; set; } @@ -33,7 +33,7 @@ public static IntPtr GetContinuousActionEventId(IntPtr response) return GetContinuousActionEventIdNative(response); } - [DllImport("rlnetnative", EntryPoint = "GetContinuousActionModelId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetContinuousActionModelId")] private static extern IntPtr GetContinuousActionModelIdNative(IntPtr response); internal static Func GetContinuousActionModelIdOverride { get; set; } @@ -48,10 +48,10 @@ public static IntPtr GetContinuousActionModelId(IntPtr response) return GetContinuousActionModelIdNative(response); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern float GetContinuousActionChosenAction(IntPtr response); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern float GetContinuousActionChosenActionPdfValue(IntPtr response); } } diff --git a/bindings/cs/rl.net/DecisionResponse.cs b/bindings/cs/rl.net/DecisionResponse.cs index fbe275447..7f58b23ee 100644 --- a/bindings/cs/rl.net/DecisionResponse.cs +++ b/bindings/cs/rl.net/DecisionResponse.cs @@ -12,25 +12,25 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateSlotResponse(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr GetSlotSlotId(IntPtr slotResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int GetSlotActionId(IntPtr slotResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern float GetSlotProbability(IntPtr slotResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateDecisionResponse(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteDecisionResponse(IntPtr decisionResponse); - [DllImport("rlnetnative", EntryPoint = "GetDecisionModelId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetDecisionModelId")] private static extern IntPtr GetDecisionModelIdNative(IntPtr decisionResponse); internal static Func GetDecisionModelIdOverride { get; set; } @@ -47,7 +47,7 @@ public static IntPtr GetDecisionModelId(IntPtr decisionResponse) // TODO: CLS-compliance requires that we not publically expose unsigned types. // Probably not a big issue ("9e18 actions ought to be enough for anyone...") - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern UIntPtr GetDecisionSize(IntPtr decisionResponse); } } @@ -133,7 +133,7 @@ IEnumerator IEnumerable.GetEnumerator() private class DecisionResponseEnumerator : NativeObject, IEnumerator { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateDecisionEnumeratorAdapter(IntPtr decisionResponse); private static New BindConstructorArguments(DecisionResponse decisionResponse) @@ -148,16 +148,16 @@ private static New BindConstructorArguments(Decision }); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteDecisionEnumeratorAdapter(IntPtr decisionEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int DecisionEnumeratorInit(IntPtr decisionEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int DecisionEnumeratorMoveNext(IntPtr decisionEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr GetDecisionEnumeratorCurrent(IntPtr decisionEnumeratorAdapter); private bool initialState = true; diff --git a/bindings/cs/rl.net/FactoryContext.cs b/bindings/cs/rl.net/FactoryContext.cs index 16533f873..4c04705c9 100644 --- a/bindings/cs/rl.net/FactoryContext.cs +++ b/bindings/cs/rl.net/FactoryContext.cs @@ -9,16 +9,16 @@ namespace Rl.Net { public sealed class FactoryContext : NativeObject { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateFactoryContext(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateFactoryContextWithStaticModel(IntPtr vw_model, int len); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteFactoryContext(IntPtr context); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr SetFactoryContextBindingSenderFactory(IntPtr context, sender_create_fn create_Fn, sender_vtable vtable); public FactoryContext() : base(new New(CreateFactoryContext), new Delete(DeleteFactoryContext)) diff --git a/bindings/cs/rl.net/LiveModel.cs b/bindings/cs/rl.net/LiveModel.cs index 2e829d011..624fddeba 100644 --- a/bindings/cs/rl.net/LiveModel.cs +++ b/bindings/cs/rl.net/LiveModel.cs @@ -11,16 +11,16 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateLiveModel(IntPtr config, IntPtr factoryContext); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteLiveModel(IntPtr liveModel); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int LiveModelInit(IntPtr liveModel, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "LiveModelChooseRank")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelChooseRank")] private static extern int LiveModelChooseRankNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr rankingResponse, IntPtr apiStatus); internal static Func LiveModelChooseRankOverride { get; set; } @@ -35,7 +35,7 @@ public static int LiveModelChooseRank(IntPtr liveModel, IntPtr eventId, IntPtr c return LiveModelChooseRankNative(liveModel, eventId, contextJson, contextJsonSize, rankingResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelChooseRankWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelChooseRankWithFlags")] private static extern int LiveModelChooseRankWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr rankingResponse, IntPtr apiStatus); internal static Func LiveModelChooseRankWithFlagsOverride { get; set; } @@ -50,7 +50,7 @@ public static int LiveModelChooseRankWithFlags(IntPtr liveModel, IntPtr eventId, return LiveModelChooseRankWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, rankingResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestContinuousAction")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestContinuousAction")] private static extern int LiveModelRequestContinuousActionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus); internal static Func LiveModelRequestContinuousActionOverride { get; set; } @@ -65,7 +65,7 @@ public static int LiveModelRequestContinuousAction(IntPtr liveModel, IntPtr even return LiveModelRequestContinuousActionNative(liveModel, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestContinuousActionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestContinuousActionWithFlags")] private static extern int LiveModelRequestContinuousActionWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr apiStatus); internal static Func LiveModelRequestContinuousActionWithFlagsOverride { get; set; } @@ -80,7 +80,7 @@ public static int LiveModelRequestContinuousActionWithFlags(IntPtr liveModel, In return LiveModelRequestContinuousActionWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, continuousActionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestDecision")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestDecision")] private static extern int LiveModelRequestDecisionNative(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, IntPtr decisionResponse, IntPtr apiStatus); internal static Func LiveModelRequestDecisionOverride { get; set; } @@ -95,7 +95,7 @@ public static int LiveModelRequestDecision(IntPtr liveModel, IntPtr contextJson, return LiveModelRequestDecisionNative(liveModel, contextJson, contextJsonSize, decisionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestDecisionWithFlags")] private static extern int LiveModelRequestDecisionWithFlagsNative(IntPtr liveModel, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr decisionResponse, IntPtr apiStatus); internal static Func LiveModelRequestDecisionWithFlagsOverride { get; set; } @@ -110,7 +110,7 @@ public static int LiveModelRequestDecisionWithFlags(IntPtr liveModel, IntPtr con return LiveModelRequestDecisionWithFlagsNative(liveModel, contextJson, contextJsonSize, flags, decisionResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecision")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecision")] private static extern int LiveModelRequestMultiSlotDecisionNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionOverride { get; set; } @@ -125,7 +125,7 @@ public static int LiveModelRequestMultiSlotDecision(IntPtr liveModel, IntPtr eve return LiveModelRequestMultiSlotDecisionNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecisionWithFlags")] private static extern int LiveModelRequestMultiSlotDecisionWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionWithFlagsOverride { get; set; } @@ -140,7 +140,7 @@ public static int LiveModelRequestMultiSlotDecisionWithFlags(IntPtr liveModel, I return LiveModelRequestMultiSlotDecisionWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecisionWithBaselineAndFlags")] private static extern int LiveModelRequestMultiSlotDecisionWithBaselineAndFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionWithBaselineAndFlagsOverride { get; set; } @@ -155,7 +155,7 @@ public static int LiveModelRequestMultiSlotDecisionWithBaselineAndFlags(IntPtr l return LiveModelRequestMultiSlotDecisionWithBaselineAndFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailed")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecisionDetailed")] private static extern int LiveModelRequestMultiSlotDecisionDetailedNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionDetailedOverride { get; set; } @@ -170,7 +170,7 @@ public static int LiveModelRequestMultiSlotDecisionDetailed(IntPtr liveModel, In return LiveModelRequestMultiSlotDecisionDetailedNative(liveModel, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailedWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecisionDetailedWithFlags")] private static extern int LiveModelRequestMultiSlotDecisionDetailedWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionDetailedWithFlagsOverride { get; set; } @@ -185,7 +185,7 @@ public static int LiveModelRequestMultiSlotDecisionDetailedWithFlags(IntPtr live return LiveModelRequestMultiSlotDecisionDetailedWithFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] private static extern int LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlagsOverride { get; set; } @@ -200,7 +200,7 @@ public static int LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlags( return LiveModelRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(liveModel, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelRequestEpisodicDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelRequestEpisodicDecisionWithFlags")] private static extern int LiveModelRequestEpisodicDecisionWithFlagsNative(IntPtr liveModel, IntPtr eventId, IntPtr previousEventId, IntPtr contextJson, uint flags, IntPtr rankingResponse, IntPtr episodes, IntPtr apiStatus); internal static Func LiveModelRequestEpisodicDecisionWithFlagsOverride { get; set; } @@ -215,7 +215,7 @@ public static int LiveModelRequestEpisodicDecisionWithFlags(IntPtr liveModel, In return LiveModelRequestEpisodicDecisionWithFlagsNative(liveModel, eventId, previousEventId, contextJson, flags, rankingResponse, episodes, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportActionTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportActionTaken")] private static extern int LiveModelReportActionTakenNative(IntPtr liveModel, IntPtr eventId, IntPtr apiStatus); internal static Func LiveModelReportActionTakenOverride { get; set; } @@ -230,7 +230,7 @@ public static int LiveModelReportActionTaken(IntPtr liveModel, IntPtr eventId, I return LiveModelReportActionTakenNative(liveModel, eventId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportActionMultiIdTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportActionMultiIdTaken")] private static extern int LiveModelReportActionTakenMultiIdNative(IntPtr liveModel, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); internal static Func LiveModelReportActionTakenMultiIdOverride { get; set; } @@ -245,7 +245,7 @@ public static int LiveModelReportActionMultiIdTaken(IntPtr liveModel, IntPtr pri return LiveModelReportActionTakenMultiIdNative(liveModel, primaryId, secondaryId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeF")] private static extern int LiveModelReportOutcomeFNative(IntPtr liveModel, IntPtr eventId, float outcome, IntPtr apiStatus); internal static Func LiveModelReportOutcomeFOverride { get; set; } @@ -260,7 +260,7 @@ public static int LiveModelReportOutcomeF(IntPtr liveModel, IntPtr eventId, floa return LiveModelReportOutcomeFNative(liveModel, eventId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeJson")] private static extern int LiveModelReportOutcomeJsonNative(IntPtr liveModel, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func LiveModelReportOutcomeJsonOverride { get; set; } @@ -275,7 +275,7 @@ public static int LiveModelReportOutcomeJson(IntPtr liveModel, IntPtr eventId, I return LiveModelReportOutcomeJsonNative(liveModel, eventId, outcomeJson, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeSlotF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeSlotF")] private static extern int LiveModelReportOutcomeSlotFNative(IntPtr liveModel, IntPtr eventId, uint slotIndex, float outcome, IntPtr apiStatus); internal static Func LiveModelReportOutcomeSlotFOverride { get; set; } @@ -290,7 +290,7 @@ public static int LiveModelReportOutcomeSlotF(IntPtr liveModel, IntPtr eventId, return LiveModelReportOutcomeSlotFNative(liveModel, eventId, slotIndex, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeSlotJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeSlotJson")] private static extern int LiveModelReportOutcomeSlotJsonNative(IntPtr liveModel, IntPtr eventId, uint slotIndex, IntPtr outcomeJson, IntPtr apiStatus); internal static Func LiveModelReportOutcomeSlotJsonOverride { get; set; } @@ -305,7 +305,7 @@ public static int LiveModelReportOutcomeSlotJson(IntPtr liveModel, IntPtr eventI return LiveModelReportOutcomeSlotJsonNative(liveModel, eventId, slotIndex, outcomeJson, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeSlotStringIdF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeSlotStringIdF")] private static extern int LiveModelReportOutcomeSlotStringIdFNative(IntPtr liveModel, IntPtr eventId, IntPtr slotId, float outcome, IntPtr apiStatus); internal static Func LiveModelReportOutcomeSlotStringIdFOverride { get; set; } @@ -320,7 +320,7 @@ public static int LiveModelReportOutcomeSlotStringIdF(IntPtr liveModel, IntPtr e return LiveModelReportOutcomeSlotStringIdFNative(liveModel, eventId, slotId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "LiveModelReportOutcomeSlotStringIdJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "LiveModelReportOutcomeSlotStringIdJson")] private static extern int LiveModelReportOutcomeSlotStringIdJsonNative(IntPtr liveModel, IntPtr eventId, IntPtr slotId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func LiveModelReportOutcomeSlotStringIdJsonOverride { get; set; } @@ -335,13 +335,13 @@ public static int LiveModelReportOutcomeSlotStringIdJson(IntPtr liveModel, IntPt return LiveModelReportOutcomeSlotStringIdJsonNative(liveModel, eventId, slotId, outcomeJson, apiStatus); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int LiveModelRefreshModel(IntPtr liveModel, IntPtr apiStatus); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void LiveModelSetCallback(IntPtr liveModel, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void LiveModelSetTrace(IntPtr liveModel, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); } } diff --git a/bindings/cs/rl.net/MultiSlotResponse.cs b/bindings/cs/rl.net/MultiSlotResponse.cs index 26b7cd485..777c62d23 100644 --- a/bindings/cs/rl.net/MultiSlotResponse.cs +++ b/bindings/cs/rl.net/MultiSlotResponse.cs @@ -12,19 +12,19 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int GetSlotEntryActionId(IntPtr slotEntryResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern float GetSlotEntryProbability(IntPtr slotEntryResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateMultiSlotResponse(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteMultiSlotResponse(IntPtr multiSlotResponse); - [DllImport("rlnetnative", EntryPoint = "GetMultiSlotModelId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetMultiSlotModelId")] private static extern IntPtr GetMultiSlotModelIdNative(IntPtr multiSlotResponse); internal static Func GetMultiSlotModelIdOverride { get; set; } @@ -40,7 +40,7 @@ public static IntPtr GetMultiSlotModelId(IntPtr multiSlotResponse) } - [DllImport("rlnetnative", EntryPoint = "GetMultiSlotEventId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetMultiSlotEventId")] private static extern IntPtr GetMultiSlotEventIdNative(IntPtr multiSlotResponse); internal static Func GetMultiSlotEventIdOverride { get; set; } @@ -57,7 +57,7 @@ public static IntPtr GetMultiSlotEventId(IntPtr multiSlotResponse) // TODO: CLS-compliance requires that we not publically expose unsigned types. // Probably not a big issue ("9e18 actions ought to be enough for anyone...") - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern UIntPtr GetMultiSlotSize(IntPtr decisionResponse); } } @@ -144,7 +144,7 @@ IEnumerator IEnumerable.GetEnumerator() private class MultiSlotResponseEnumerator : NativeObject, IEnumerator { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateMultiSlotEnumeratorAdapter(IntPtr multiSlotResponse); private static New BindConstructorArguments(MultiSlotResponse multiSlotResponse) @@ -158,16 +158,16 @@ private static New BindConstructorArguments(MultiSl }); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteMultiSlotEnumeratorAdapter(IntPtr multiSlotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int MultiSlotEnumeratorInit(IntPtr multiSlotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int MultiSlotEnumeratorMoveNext(IntPtr multiSlotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr GetMultiSlotEnumeratorCurrent(IntPtr multiSlotEnumeratorAdapter); private bool initialState = true; diff --git a/bindings/cs/rl.net/MultiSlotResponseDetailed.cs b/bindings/cs/rl.net/MultiSlotResponseDetailed.cs index 002243a15..fff157c0e 100644 --- a/bindings/cs/rl.net/MultiSlotResponseDetailed.cs +++ b/bindings/cs/rl.net/MultiSlotResponseDetailed.cs @@ -13,13 +13,13 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateMultiSlotResponseDetailed(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteMultiSlotResponseDetailed(IntPtr multiSlotResponseDetailed); - [DllImport("rlnetnative", EntryPoint = "GetMultiSlotDetailedModelId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetMultiSlotDetailedModelId")] private static extern IntPtr GetMultiSlotDetailedModelIdNative(IntPtr multiSlotResponseDetailed); internal static Func GetMultiSlotDetailedModelIdOverride { get; set; } @@ -34,7 +34,7 @@ public static IntPtr GetMultiSlotDetailedModelId(IntPtr multiSlotResponseDetaile return GetMultiSlotDetailedModelIdNative(multiSlotResponseDetailed); } - [DllImport("rlnetnative", EntryPoint = "GetMultiSlotDetailedEventID")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetMultiSlotDetailedEventID")] private static extern IntPtr GetMultiSlotDetailedEventIDNative(IntPtr multiSlotResponseDetailed); internal static Func GetMultiSlotDetailedEventIdOverride { get; set; } @@ -51,7 +51,7 @@ public static IntPtr GetMultiSlotDetailedEventId(IntPtr multiSlotResponseDetaile // TODO: CLS-compliance requires that we not publically expose unsigned types. // Probably not a big issue ("9e18 actions ought to be enough for anyone...") - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern UIntPtr GetMultiSlotDetailedSize(IntPtr multiSlot); } } @@ -109,7 +109,7 @@ IEnumerator IEnumerable.GetEnumerator() private class MultiSlotResponseDetailedEnumerator : NativeObject, IEnumerator { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateMultiSlotDetailedEnumeratorAdapter(IntPtr multiSlotResponseDetailed); private static New BindConstructorArguments(MultiSlotResponseDetailed multiSlotResponseDetailed) @@ -124,16 +124,16 @@ private static New BindConstructorArguments }); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteMultiSlotDetailedEnumeratorAdapter(IntPtr multiSlotDetailedEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int MultiSlotDetailedEnumeratorInit(IntPtr multiSlotDetailedEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int MultiSlotDetailedEnumeratorMoveNext(IntPtr multiSlotDetailedEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr GetMultiSlotDetailedEnumeratorCurrent(IntPtr multiSlotDetailedEnumeratorAdapter); private bool initialState = true; diff --git a/bindings/cs/rl.net/Native/Global.cs b/bindings/cs/rl.net/Native/Global.cs index 92e181779..3d4c0932e 100644 --- a/bindings/cs/rl.net/Native/Global.cs +++ b/bindings/cs/rl.net/Native/Global.cs @@ -22,7 +22,7 @@ public static IntPtr ToNativeHandleOrNullptrDangerous(this NativeObject return nativeObject.DangerousGetHandle(); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr LookupMessageForErrorCode(int error_code); public static string MarshalMessageForErrorCode(int error_code) diff --git a/bindings/cs/rl.net/Native/NativeImports.cs b/bindings/cs/rl.net/Native/NativeImports.cs new file mode 100644 index 000000000..648dc0fdc --- /dev/null +++ b/bindings/cs/rl.net/Native/NativeImports.cs @@ -0,0 +1,9 @@ +namespace Rl.Net.Native { + internal static class NativeImports { + #if DEBUG + internal const string RLNETNATIVE = "rlnetnatived"; + #else + internal const string RLNETNATIVE = "rlnetnative"; + #endif + } +} diff --git a/bindings/cs/rl.net/NativeCallbacks.cs b/bindings/cs/rl.net/NativeCallbacks.cs index e40106be6..6a6e24971 100644 --- a/bindings/cs/rl.net/NativeCallbacks.cs +++ b/bindings/cs/rl.net/NativeCallbacks.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; namespace Rl.Net { @@ -8,6 +9,8 @@ internal static partial class NativeMethods { public delegate void managed_background_error_callback_t(IntPtr apiStatus); public delegate void managed_trace_callback_t(int logLevel, IntPtr msgUtf8Ptr); + public delegate int managed_oauth_callback_t(IntPtr scopes, IntPtr tokenOutPtr, IntPtr unixTimestamp); + public delegate void managed_oauth_callback_t_complete_t(IntPtr tokenStringToFree, int errorCode); } } } \ No newline at end of file diff --git a/bindings/cs/rl.net/OAuthCredentialProvider.cs b/bindings/cs/rl.net/OAuthCredentialProvider.cs new file mode 100644 index 000000000..bd21e0b93 --- /dev/null +++ b/bindings/cs/rl.net/OAuthCredentialProvider.cs @@ -0,0 +1,143 @@ +using System; +using System.ComponentModel; +using System.Runtime.InteropServices; +using System.Threading.Tasks; +using System.Collections.Generic; +using Rl.Net.Native; + +namespace Rl.Net +{ + namespace Native + { + internal static partial class NativeMethods + { + [DllImport(NativeImports.RLNETNATIVE)] + public static extern void RegisterDefaultFactoriesCallback( + [MarshalAs(UnmanagedType.FunctionPtr)] managed_oauth_callback_t callback, + [MarshalAs(UnmanagedType.FunctionPtr)] managed_oauth_callback_t_complete_t completion); + } + } + + public class OAuthTokenRequestedEventArgs : EventArgs + { + public IList Scopes { get; } + + public string Token { get; set; } + + public DateTime TokenExpirationTime { get; set; } + + public int ErrorCode { get; set; } + + public OAuthTokenRequestedEventArgs(IList scopes) + { + Scopes = scopes; + } + } + + /// + /// OAuthCredentialProvider is used to provide OAuth tokens to the underlying rlclientlib library. + /// + /// + /// OAuthCredentialProvider is currently static and should be sufficient for most applications. + /// To control the OAuth token retrieval process, subscribe to the OAuthTokenRequested event. + /// Use scopes if you need to request different tokens for different scopes which should be handled + /// in the event handler. + /// + /// Example Usage: + /// + /// static class EntryPoints + /// { + /// private static ManagedIdentityCredential managedIdentityCredential; + /// + /// private static void OnOAuthTokenRequest(object src, OAuthTokenRequestedEventArgs e) + /// { + /// try { + /// var accessToken = managedIdentityCredential.GetToken(e.Scopes); + /// e.Token = accessToken.Token + /// e.TokenExpirationTime = accessToken.ExpiresOn; + /// e.ErrorCode = 0; + /// } + /// catch (AuthenticationFailedException) { + /// e.ErrorCode = 4; // http_bad_status_code + /// } + /// } + /// + /// public static void Main(string[] args) + /// { + /// // some startup code is here; do not use anything that would invoke an OAuth token request + /// var my_client_id = args[1]; + /// managedIdentityCredential = new ManagedIdentityCredential(my_client_id); + /// // ... + /// // setup the OAuthCredentialProvider + /// OAuthCredentialProvider.OAuthTokenRequested += OnOAuthTokenRequest; + /// // now OAuth requests are ready to be handled + /// } + /// } + /// + /// + public static class OAuthCredentialProvider + { + private static readonly NativeMethods.managed_oauth_callback_t oauthCredentialCallback; + private static readonly NativeMethods.managed_oauth_callback_t_complete_t oauthCredentialCallbackCompletion; + + static OAuthCredentialProvider() + { + oauthCredentialCallback = new NativeMethods.managed_oauth_callback_t(WrapOAuthCredentialCallback); + oauthCredentialCallbackCompletion = new NativeMethods.managed_oauth_callback_t_complete_t(WrapOAuthCredentialCallbackCompletion); + NativeMethods.RegisterDefaultFactoriesCallback(oauthCredentialCallback, oauthCredentialCallbackCompletion); + } + + public static event EventHandler OAuthTokenRequested; + + private static int WrapOAuthCredentialCallback(IntPtr scopes, IntPtr tokenOutPtr, IntPtr expiryUnixTime) + { + var scopesArray = StringArrayFromNativeUtf8Strings(scopes); + var e = new OAuthTokenRequestedEventArgs(scopesArray); + OAuthTokenRequested?.Invoke(null, e); + if (e.ErrorCode == 0) + { + var outPtr = CreateUnmanagedString(e.Token ?? ""); + Marshal.WriteIntPtr(tokenOutPtr, outPtr); + Marshal.WriteInt64(expiryUnixTime, (new DateTimeOffset(e.TokenExpirationTime)).ToUnixTimeSeconds()); + } + return e.ErrorCode; + } + + private static void WrapOAuthCredentialCallbackCompletion(IntPtr tokenStringToFree, int _) + { + if (tokenStringToFree != IntPtr.Zero) + { + Marshal.FreeHGlobal(tokenStringToFree); + } + } + + private static IList StringArrayFromNativeUtf8Strings(IntPtr nativeStrings) + { + if (nativeStrings == IntPtr.Zero) + { + throw new ArgumentNullException(nameof(nativeStrings)); + } + var list = new List(); + IntPtr strPtr = Marshal.ReadIntPtr(nativeStrings); + for (int i = 1; strPtr != IntPtr.Zero; ++i) + { + list.Add(Marshal.PtrToStringAnsi(strPtr)); + strPtr = Marshal.ReadIntPtr(nativeStrings, i * IntPtr.Size); + } + return list; + } + + private static IntPtr CreateUnmanagedString(string str) + { + if (str == null) + { + throw new ArgumentNullException(nameof(str)); + } + byte[] bytes = System.Text.Encoding.UTF8.GetBytes(str); + IntPtr unmanagedString = Marshal.AllocHGlobal(bytes.Length + 1); + Marshal.Copy(bytes, 0, unmanagedString, bytes.Length); + Marshal.WriteByte(unmanagedString, bytes.Length, 0); + return unmanagedString; + } + } +} \ No newline at end of file diff --git a/bindings/cs/rl.net/RankingResponse.cs b/bindings/cs/rl.net/RankingResponse.cs index 7cfd59db5..e7bc3c30b 100644 --- a/bindings/cs/rl.net/RankingResponse.cs +++ b/bindings/cs/rl.net/RankingResponse.cs @@ -12,13 +12,13 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateRankingResponse(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteRankingResponse(IntPtr rankingResponse); - [DllImport("rlnetnative", EntryPoint = "GetRankingEventId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetRankingEventId")] private static extern IntPtr GetRankingEventIdNative(IntPtr rankingResponse); internal static Func GetRankingEventIdOverride { get; set; } @@ -33,7 +33,7 @@ public static IntPtr GetRankingEventId(IntPtr rankingResponse) return GetRankingEventIdNative(rankingResponse); } - [DllImport("rlnetnative", EntryPoint = "GetRankingModelId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetRankingModelId")] private static extern IntPtr GetRankingModelIdNative(IntPtr rankingResponse); internal static Func GetRankingModelIdOverride { get; set; } @@ -50,10 +50,10 @@ public static IntPtr GetRankingModelId(IntPtr rankingResponse) // TODO: CLS-compliance requires that we not publically expose unsigned types. // Probably not a big issue ("9e18 actions ought to be enough for anyone...") - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern UIntPtr GetRankingActionCount(IntPtr rankingResponse); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int GetRankingChosenAction(IntPtr rankingResponse, out UIntPtr action_id, IntPtr status); } } @@ -162,7 +162,7 @@ IEnumerator IEnumerable.GetEnumerator() private class RankingResponseEnumerator : NativeObject, IEnumerator { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateRankingEnumeratorAdapter(IntPtr rankingResponse); private static New BindConstructorArguments(RankingResponse rankingResponse) @@ -176,16 +176,16 @@ private static New BindConstructorArguments(RankingRe }); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteRankingEnumeratorAdapter(IntPtr rankingEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int RankingEnumeratorInit(IntPtr rankingEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int RankingEnumeratorMoveNext(IntPtr rankingEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern ActionProbability GetRankingEnumeratorCurrent(IntPtr rankingEnumeratorAdapter); private bool initialState = true; diff --git a/bindings/cs/rl.net/SharedBuffer.cs b/bindings/cs/rl.net/SharedBuffer.cs index 0183f95e4..f78cb0c0b 100644 --- a/bindings/cs/rl.net/SharedBuffer.cs +++ b/bindings/cs/rl.net/SharedBuffer.cs @@ -5,16 +5,16 @@ namespace Rl.Net { public sealed class SharedBuffer : NativeObject { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CloneBufferSharedPointer(IntPtr original); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void ReleaseBufferSharedPointer(IntPtr shared_buffer); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr GetSharedBufferBegin(IntPtr shared_buffer); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern UIntPtr GetSharedBufferLength(IntPtr status); private static New BindConstructorArguments(SharedBuffer original) diff --git a/bindings/cs/rl.net/SlatesLoop.cs b/bindings/cs/rl.net/SlatesLoop.cs index d3cce6fd8..93efb748d 100644 --- a/bindings/cs/rl.net/SlatesLoop.cs +++ b/bindings/cs/rl.net/SlatesLoop.cs @@ -10,16 +10,16 @@ namespace Native // The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation internal static partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateSlatesLoop(IntPtr config, IntPtr factoryContext); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteSlatesLoop(IntPtr slatesLoop); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int SlatesLoopInit(IntPtr slatesLoop, IntPtr apiStatus); - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecision")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecision")] private static extern int SlatesLoopRequestMultiSlotDecisionNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionOverride { get; set; } @@ -34,7 +34,7 @@ public static int SlatesLoopRequestMultiSlotDecision(IntPtr slatesLoop, IntPtr e return SlatesLoopRequestMultiSlotDecisionNative(slatesLoop, eventId, contextJson, contextJsonSize, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecisionWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecisionWithFlags")] private static extern int SlatesLoopRequestMultiSlotDecisionWithFlagsNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionWithFlagsOverride { get; set; } @@ -49,7 +49,7 @@ public static int SlatesLoopRequestMultiSlotDecisionWithFlags(IntPtr slatesLoop, return SlatesLoopRequestMultiSlotDecisionWithFlagsNative(slatesLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlags")] private static extern int SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlagsNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponse, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlagsOverride { get; set; } @@ -64,7 +64,7 @@ public static int SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlags(IntPtr return SlatesLoopRequestMultiSlotDecisionWithBaselineAndFlagsNative(slatesLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponse, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailed")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailed")] private static extern int SlatesLoopRequestMultiSlotDecisionDetailedNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionDetailedOverride { get; set; } @@ -79,7 +79,7 @@ public static int SlatesLoopRequestMultiSlotDecisionDetailed(IntPtr slatesLoop, return SlatesLoopRequestMultiSlotDecisionDetailedNative(slatesLoop, eventId, contextJson, contextJsonSize, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailedWithFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailedWithFlags")] private static extern int SlatesLoopRequestMultiSlotDecisionDetailedWithFlagsNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionDetailedWithFlagsOverride { get; set; } @@ -94,7 +94,7 @@ public static int SlatesLoopRequestMultiSlotDecisionDetailedWithFlags(IntPtr sla return SlatesLoopRequestMultiSlotDecisionDetailedWithFlagsNative(slatesLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags")] private static extern int SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(IntPtr slatesLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr multiSlotResponseDetailed, IntPtr baselineActions, IntPtr baselineActionsSize, IntPtr apiStatus); internal static Func SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsOverride { get; set; } @@ -109,7 +109,7 @@ public static int SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlags return SlatesLoopRequestMultiSlotDecisionDetailedWithBaselineAndFlagsNative(slatesLoop, eventId, contextJson, contextJsonSize, flags, multiSlotResponseDetailed, baselineActions, baselineActionsSize, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopReportActionTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopReportActionTaken")] private static extern int SlatesLoopReportActionTakenNative(IntPtr slatesLoop, IntPtr eventId, IntPtr apiStatus); internal static Func SlatesLoopReportActionTakenOverride { get; set; } @@ -124,7 +124,7 @@ public static int SlatesLoopReportActionTaken(IntPtr slatesLoop, IntPtr eventId, return SlatesLoopReportActionTakenNative(slatesLoop, eventId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopReportActionMultiIdTaken")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopReportActionMultiIdTaken")] private static extern int SlatesLoopReportActionTakenMultiIdNative(IntPtr slatesLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus); internal static Func SlatesLoopReportActionTakenMultiIdOverride { get; set; } @@ -139,7 +139,7 @@ public static int SlatesLoopReportActionMultiIdTaken(IntPtr slatesLoop, IntPtr p return SlatesLoopReportActionTakenMultiIdNative(slatesLoop, primaryId, secondaryId, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopReportOutcomeF")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopReportOutcomeF")] private static extern int SlatesLoopReportOutcomeFNative(IntPtr slatesLoop, IntPtr eventId, float outcome, IntPtr apiStatus); internal static Func SlatesLoopReportOutcomeFOverride { get; set; } @@ -154,7 +154,7 @@ public static int SlatesLoopReportOutcomeF(IntPtr slatesLoop, IntPtr eventId, fl return SlatesLoopReportOutcomeFNative(slatesLoop, eventId, outcome, apiStatus); } - [DllImport("rlnetnative", EntryPoint = "SlatesLoopReportOutcomeJson")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "SlatesLoopReportOutcomeJson")] private static extern int SlatesLoopReportOutcomeJsonNative(IntPtr slatesLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus); internal static Func SlatesLoopReportOutcomeJsonOverride { get; set; } @@ -169,13 +169,13 @@ public static int SlatesLoopReportOutcomeJson(IntPtr slatesLoop, IntPtr eventId, return SlatesLoopReportOutcomeJsonNative(slatesLoop, eventId, outcomeJson, apiStatus); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int SlatesLoopRefreshModel(IntPtr slatesLoop, IntPtr apiStatus); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void SlatesLoopSetCallback(IntPtr slatesLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void SlatesLoopSetTrace(IntPtr slatesLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null); } } diff --git a/bindings/cs/rl.net/SlotRanking.cs b/bindings/cs/rl.net/SlotRanking.cs index 10bd2aa68..221e77a03 100644 --- a/bindings/cs/rl.net/SlotRanking.cs +++ b/bindings/cs/rl.net/SlotRanking.cs @@ -13,13 +13,13 @@ namespace Native { internal partial class NativeMethods { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern IntPtr CreateSlotRanking(); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern void DeleteSlotRanking(IntPtr slot); - [DllImport("rlnetnative", EntryPoint = "GetSlotId")] + [DllImport(NativeImports.RLNETNATIVE, EntryPoint = "GetSlotId")] private static extern IntPtr GetSlotIdNative(IntPtr slotRanking); internal static Func GetSlotIdOverride { get; set; } @@ -34,10 +34,10 @@ public static IntPtr GetSlotId(IntPtr slotRanking) return GetSlotIdNative(slotRanking); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern UIntPtr GetSlotActionCount(IntPtr slotRanking); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] public static extern int GetSlotChosenAction(IntPtr slotRanking, out UIntPtr action_id, IntPtr status); } } @@ -122,7 +122,7 @@ IEnumerator IEnumerable.GetEnumerator() private class SlotRankingEnumerator : NativeObject, IEnumerator { - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern IntPtr CreateSlotEnumeratorAdapter(IntPtr slotResponse); private static New BindConstructorArguments(SlotRanking slotRanking) @@ -136,16 +136,16 @@ private static New BindConstructorArguments(SlotRanking s }); } - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern void DeleteSlotEnumeratorAdapter(IntPtr slotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int SlotEnumeratorInit(IntPtr slotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern int SlotEnumeratorMoveNext(IntPtr slotEnumeratorAdapter); - [DllImport("rlnetnative")] + [DllImport(NativeImports.RLNETNATIVE)] private static extern ActionProbability GetSlotEnumeratorCurrent(IntPtr slotEnumeratorAdapter); private bool initialState = true; diff --git a/examples/rl_sim_cpp/CMakeLists.txt b/examples/rl_sim_cpp/CMakeLists.txt index 39a3ce35e..27018ba62 100644 --- a/examples/rl_sim_cpp/CMakeLists.txt +++ b/examples/rl_sim_cpp/CMakeLists.txt @@ -6,11 +6,6 @@ set(RL_SIM_SOURCES robot_joint.cc rl_sim.cc ) -if(RL_LINK_AZURE_LIBS) - list(APPEND RL_SIM_SOURCES - azure_credentials.cc - ) -endif() add_executable(rl_sim_cpp.out ${RL_SIM_SOURCES} diff --git a/examples/rl_sim_cpp/azure_credentials.cc b/examples/rl_sim_cpp/azure_credentials.cc deleted file mode 100644 index a2affffd1..000000000 --- a/examples/rl_sim_cpp/azure_credentials.cc +++ /dev/null @@ -1,67 +0,0 @@ -#ifdef LINK_AZURE_LIBS -# include "azure_credentials.h" - -# include "err_constants.h" -# include "future_compat.h" - -# include -# include -// These are needed because azure does a bad time conversion -# include -# include -# include -# include - -using namespace reinforcement_learning; - -AzureCredentials::AzureCredentials(const std::string& tenant_id) : _tenant_id(tenant_id), _creds(create_options()) {} - -Azure::Identity::AzureCliCredentialOptions AzureCredentials::create_options() -{ - Azure::Identity::AzureCliCredentialOptions options; - options.TenantId = _tenant_id; - options.AdditionallyAllowedTenants.push_back("*"); - return options; -} - -int AzureCredentials::get_credentials( - const std::vector& scopes, std::string& token_out, std::chrono::system_clock::time_point& expiry_out) -{ -# ifdef HAS_STD14 - Azure::Core::Credentials::TokenRequestContext request_context; - request_context.Scopes = scopes; - // TODO: needed? - request_context.TenantId = _tenant_id; - Azure::Core::Context context; - try - { - auto auth = _creds.GetToken(request_context, context); - token_out = auth.Token; - - // Casting from an azure DateTime object to a time_point does the calculation - // incorrectly. The expiration is returned as a local time, but the library - // assumes that it is GMT, and converts the value incorrectly. - // See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075 - // expiry_out = static_cast(auth.ExpiresOn); - std::string dt_string = auth.ExpiresOn.ToString(); - std::tm tm = {}; - std::istringstream ss(dt_string); - ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); - expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); - } - catch (std::exception& e) - { - std::cout << "Error getting auth token: " << e.what(); - return error_code::external_error; - } - catch (...) - { - std::cout << "Unknown error while getting auth token"; - return error_code::external_error; - } -# else -# error Requires C++14 or greater -# endif - return error_code::success; -} -#endif \ No newline at end of file diff --git a/examples/rl_sim_cpp/azure_credentials.h b/examples/rl_sim_cpp/azure_credentials.h deleted file mode 100644 index 30c2256e0..000000000 --- a/examples/rl_sim_cpp/azure_credentials.h +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#ifdef LINK_AZURE_LIBS -# include "api_status.h" -# include "configuration.h" -# include "future_compat.h" - -# include -# include -# include -# include -# include - -class AzureCredentials -{ -public: - AzureCredentials(const std::string& tenant_id); - int get_credentials(const std::vector& scopes, std::string& token_out, - std::chrono::system_clock::time_point& expiry_out); - -private: - std::string _tenant_id; -# ifdef HAS_STD14 - Azure::Identity::AzureCliCredentialOptions create_options(); - - // Azure::Identity::DefaultAzureCredential _creds; - Azure::Identity::AzureCliCredential _creds; -# endif -}; -#endif \ No newline at end of file diff --git a/examples/rl_sim_cpp/main.cc b/examples/rl_sim_cpp/main.cc index 47eb72e1c..324d6051e 100644 --- a/examples/rl_sim_cpp/main.cc +++ b/examples/rl_sim_cpp/main.cc @@ -39,8 +39,7 @@ po::variables_map process_cmd_line(const int argc, char** argv) "delay", po::value()->default_value(2000), "Delay between events in ms")( "quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value()->default_value(true), "Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")( - "azure_oauth_factories", po::value()->default_value(false), "Use oauth for azure factores. Default false")( - "azure_tenant_id", po::value()->default_value(""), "Tenant ID for use with azure oauth factories."); + "azure_oauth_factories", po::value()->default_value(false), "Use oauth for azure factores. Default false"); po::variables_map vm; store(parse_command_line(argc, argv, desc), vm); diff --git a/examples/rl_sim_cpp/rl_sim.cc b/examples/rl_sim_cpp/rl_sim.cc index 0cb4be798..e992a0234 100644 --- a/examples/rl_sim_cpp/rl_sim.cc +++ b/examples/rl_sim_cpp/rl_sim.cc @@ -496,7 +496,7 @@ int rl_sim::init_rl() // Note: This requires C++14 or better using namespace std::placeholders; reinforcement_learning::oauth_callback_t callback = - std::bind(&AzureCredentials::get_credentials, &_creds, _1, _2, _3); + std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3); reinforcement_learning::register_default_factories_callback(callback); #endif } @@ -653,10 +653,8 @@ std::string rl_sim::create_event_id() rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)) , _loop_kind(CB) -#ifdef LINK_AZURE_LIBS - , _creds(_options["azure_tenant_id"].as()) -#endif { + if (_options["ccb"].as()) { _loop_kind = CCB; } else if (_options["slates"].as()) { _loop_kind = Slates; } else if (_options["ca"].as()) { _loop_kind = CA; } diff --git a/examples/rl_sim_cpp/rl_sim.h b/examples/rl_sim_cpp/rl_sim.h index 5cd68bd92..40699bcac 100644 --- a/examples/rl_sim_cpp/rl_sim.h +++ b/examples/rl_sim_cpp/rl_sim.h @@ -6,7 +6,12 @@ * @date 2018-07-18 */ #pragma once -#include "azure_credentials.h" +#include "azure_credentials_provider.h" + +#ifdef LINK_AZURE_LIBS +#include +#endif + #include "live_model.h" #include "person.h" #include "robot_joint.h" @@ -179,6 +184,8 @@ class rl_sim bool _quiet = false; bool _random_ids = true; #ifdef LINK_AZURE_LIBS - AzureCredentials _creds; + using azure_cred_t = Azure::Identity::DefaultAzureCredential; + using azure_credentials_provider_t = reinforcement_learning::azure_credentials_provider; + azure_credentials_provider_t _creds; #endif }; diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h new file mode 100644 index 000000000..12e7cbbd4 --- /dev/null +++ b/include/azure_credentials_provider.h @@ -0,0 +1,80 @@ +#pragma once + +#ifdef LINK_AZURE_LIBS + +#include "api_status.h" +#include "configuration.h" +#include "future_compat.h" + +#include "err_constants.h" +#include "future_compat.h" + +#include +#include +#include +// These are needed because azure does a bad time conversion +#include +#include +#include +#include + +namespace reinforcement_learning +{ + +template +class azure_credentials_provider +{ +public: + template + azure_credentials_provider(Args&&... args) : + _creds(std::make_unique(std::forward(args)...)) {} + + int get_credentials(const std::vector& scopes, std::string& token_out, + std::chrono::system_clock::time_point& expiry_out) + { + try + { + Azure::Core::Credentials::TokenRequestContext request_context; + request_context.Scopes = scopes; + + Azure::Core::Context context; + std::cout << "fetching token for " << scopes[0] << std::endl; + + std::lock_guard lock(_creds_mtx); + auto auth = _creds->GetToken(request_context, context); + token_out = auth.Token; + + // Casting from an azure DateTime object to a time_point does the calculation + // incorrectly. The expiration is returned as a local time, but the library + // assumes that it is GMT, and converts the value incorrectly. + // See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075 + // expiry_out = static_cast(auth.ExpiresOn); + std::string dt_string = auth.ExpiresOn.ToString(); + std::tm tm = {}; + std::istringstream ss(dt_string); + ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); + expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + + return 0; + } + catch (std::exception& e) + { + std::cout << "Error getting auth token: " << e.what(); + return error_code::external_error; + } + catch (...) + { + std::cout << "Unknown error while getting auth token"; + return error_code::external_error; + } + return error_code::success; + } + +private: + std::unique_ptr _creds; + mutable std::mutex _creds_mtx; +}; + +} // namespace reinforcement_learning + +#endif diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index c7b3236b5..cb37705d8 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -118,6 +118,7 @@ endif() set(PROJECT_PUBLIC_HEADERS ../include/action_flags.h ../include/api_status.h + ../include/azure_credentials_provider.h ../include/loop_apis/base_loop.h ../include/loop_apis/ca_loop.h ../include/loop_apis/cb_loop.h From 484c5e5d5c3d7c38997d7fb097d5b14ce5707cab Mon Sep 17 00:00:00 2001 From: James Longo Date: Fri, 9 Aug 2024 09:18:10 -0400 Subject: [PATCH 02/19] added memory and azure include to the azure_credentials_provider --- include/azure_credentials_provider.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h index 12e7cbbd4..4f7e91672 100644 --- a/include/azure_credentials_provider.h +++ b/include/azure_credentials_provider.h @@ -17,6 +17,11 @@ #include #include #include +#include + +# ifdef LINK_AZURE_LIBS +# include +# endif namespace reinforcement_learning { From cd6e87ad051a57e311368bf7651838670fad7465 Mon Sep 17 00:00:00 2001 From: James Longo Date: Fri, 9 Aug 2024 14:40:36 -0400 Subject: [PATCH 03/19] added trace logger to the azure credential providers added 2 new error code for failing to authenticate updated items as per PR comments --- .../rl.net.native/rl.net.azure_factories.cc | 37 ++++++---- .../cs/rl.net.native/rl.net.azure_factories.h | 2 + examples/rl_sim_cpp/rl_sim.cc | 2 +- include/azure_credentials_provider.h | 72 ++++++++++--------- include/errors_data.h | 2 + include/oauth_callback_fn.h | 3 +- rlclientlib/utility/api_header_token.h | 2 +- 7 files changed, 69 insertions(+), 51 deletions(-) diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.cc b/bindings/cs/rl.net.native/rl.net.azure_factories.cc index d850e4922..dcd7aa79f 100644 --- a/bindings/cs/rl.net.native/rl.net.azure_factories.cc +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.cc @@ -6,30 +6,38 @@ rl_net_native::azure_factory_oauth_callback_t g_oauth_callback = nullptr; rl_net_native::azure_factory_oauth_callback_complete_t g_oauth_callback_complete = nullptr; static int azure_factory_oauth_callback(const std::vector& scopes, - std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry) + std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry, + reinforcement_learning::i_trace *trace) { - if (g_oauth_callback == nullptr) { - return -1; + if (g_oauth_callback == nullptr) + { + return reinforcement_learning::error_code::invalid_argument; } - // create a null terminated array of scope string pointers - // these are pointers are readonly and owned by the caller - std::vector native_scopes; - native_scopes.reserve(scopes.size() + 1); + // create a null terminated array of scope string pointers. + // these are pointers are readonly and owned by the caller. + // since we create the vector of n elements, we can guarantee + // the last element is null. + std::vector native_scopes(scopes.size() + 1); for (int i = 0; i < scopes.size(); ++i) { - native_scopes.push_back(scopes[i].c_str()); + native_scopes[i] = scopes[i].c_str(); } - native_scopes.push_back(nullptr); - // we expect to get a pointer to a null terminated string + // we expect to get a pointer to a null terminated string. // it's allocated on the managed heap, so we can't free it here - // instead, we will pass it back to the managed code after we copy it + // instead, we will pass it back to the managed code after we copy it. char *oauth_token_ptr = nullptr; std::int64_t expiryUnixTime = 0; auto ret = g_oauth_callback(native_scopes.data(), &oauth_token_ptr, &expiryUnixTime); - if (ret == reinforcement_learning::error_code::success) { + if (ret == reinforcement_learning::error_code::success) + { + TRACE_DEBUG(trace, "rl_net_native::azure_factory_oauth_callback: successfully retrieved token"); oauth_token = oauth_token_ptr; token_expiry = std::chrono::system_clock::from_time_t(expiryUnixTime); } + else + { + TRACE_ERROR(trace, "rl_net_native::azure_factory_oauth_callback: failed to retrieve token"); + } g_oauth_callback_complete(oauth_token_ptr, reinforcement_learning::error_code::success); return ret; } @@ -38,10 +46,11 @@ static int azure_factory_oauth_callback(const std::vector& scopes, API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback, rl_net_native::azure_factory_oauth_callback_complete_t completion) { - if (rl_net_native::g_oauth_callback == nullptr) { + if (rl_net_native::g_oauth_callback == nullptr) + { rl_net_native::g_oauth_callback = callback; reinforcement_learning::register_default_factories_callback( - reinforcement_learning::oauth_callback_t{rl_net_native::azure_factory_oauth_callback}); + reinforcement_learning::oauth_callback_t { rl_net_native::azure_factory_oauth_callback }); rl_net_native::g_oauth_callback_complete = completion; } } diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.h b/bindings/cs/rl.net.native/rl.net.azure_factories.h index 1ae564c14..4812cbc4c 100644 --- a/bindings/cs/rl.net.native/rl.net.azure_factories.h +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.h @@ -17,6 +17,8 @@ typedef int (*azure_factory_oauth_callback_t)(const char**, char**, std::int64_t // Callback function to be called by the C# code to signal the completion of the OAuth token request // this provides the C# code with the opportunity to free the memory allocated for the token +// tokenStringToFree - pointer to the token string than needs to be freed +// errorCode - for future use typedef void (*azure_factory_oauth_callback_complete_t)(char*, int); } // namespace rl_net_native diff --git a/examples/rl_sim_cpp/rl_sim.cc b/examples/rl_sim_cpp/rl_sim.cc index e992a0234..fdd900458 100644 --- a/examples/rl_sim_cpp/rl_sim.cc +++ b/examples/rl_sim_cpp/rl_sim.cc @@ -496,7 +496,7 @@ int rl_sim::init_rl() // Note: This requires C++14 or better using namespace std::placeholders; reinforcement_learning::oauth_callback_t callback = - std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3); + std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4); reinforcement_learning::register_default_factories_callback(callback); #endif } diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h index 4f7e91672..676889e83 100644 --- a/include/azure_credentials_provider.h +++ b/include/azure_credentials_provider.h @@ -2,14 +2,6 @@ #ifdef LINK_AZURE_LIBS -#include "api_status.h" -#include "configuration.h" -#include "future_compat.h" - -#include "err_constants.h" -#include "future_compat.h" - -#include #include #include // These are needed because azure does a bad time conversion @@ -19,34 +11,44 @@ #include #include -# ifdef LINK_AZURE_LIBS -# include -# endif +#include + +#ifdef LINK_AZURE_LIBS +#include +#endif + +#include "err_constants.h" +#include "trace_logger.h" namespace reinforcement_learning { - template class azure_credentials_provider { public: template azure_credentials_provider(Args&&... args) : - _creds(std::make_unique(std::forward(args)...)) {} + _creds(std::make_unique(std::forward(args)...)) + { + } int get_credentials(const std::vector& scopes, std::string& token_out, - std::chrono::system_clock::time_point& expiry_out) + std::chrono::system_clock::time_point& expiry_out, i_trace* trace) { + using namespace Azure::Core; + using namespace Azure::Core::Credentials; try { - Azure::Core::Credentials::TokenRequestContext request_context; - request_context.Scopes = scopes; + TokenRequestContext request_context; + Context context; - Azure::Core::Context context; - std::cout << "fetching token for " << scopes[0] << std::endl; - - std::lock_guard lock(_creds_mtx); - auto auth = _creds->GetToken(request_context, context); + request_context.Scopes = scopes; + AccessToken auth; + { + std::lock_guard lock(_creds_mtx); + auth = _creds->GetToken(request_context, context); + TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token"); + } token_out = auth.Token; // Casting from an azure DateTime object to a time_point does the calculation @@ -59,19 +61,22 @@ class azure_credentials_provider std::istringstream ss(dt_string); ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); - - return 0; - } - catch (std::exception& e) - { - std::cout << "Error getting auth token: " << e.what(); - return error_code::external_error; - } - catch (...) - { - std::cout << "Unknown error while getting auth token"; - return error_code::external_error; } + catch (AuthenticationException& e) + { + TRACE_ERROR(trace, e.what()); + return error_code::http_oauth_authentication_error; + } + catch (std::exception& e) + { + TRACE_ERROR(trace, e.what()); + return error_code::http_oauth_unexpected_error; + } + catch (...) + { + TRACE_ERROR(trace, "azure_credentials_provider: an unexpected unknown error occurred"); + return error_code::http_oauth_unexpected_error; + } return error_code::success; } @@ -79,7 +84,6 @@ class azure_credentials_provider std::unique_ptr _creds; mutable std::mutex _creds_mtx; }; - } // namespace reinforcement_learning #endif diff --git a/include/errors_data.h b/include/errors_data.h index 791f04913..6d2c15e43 100644 --- a/include/errors_data.h +++ b/include/errors_data.h @@ -55,4 +55,6 @@ ERROR_CODE_DEFINITION(49, baseline_actions_not_defined, "Baseline Actions must b ERROR_CODE_DEFINITION(50, http_api_key_not_provided, "Http api key must be provided") ERROR_CODE_DEFINITION(51, http_model_uri_not_provided, "Model Blob URI parameter was not passed in via configuration") ERROR_CODE_DEFINITION(52, static_model_load_error, "Static model passed in C# layer is not loading properly") +ERROR_CODE_DEFINITION(53, http_oauth_authentication_error, "http request failed to authenticate") +ERROR_CODE_DEFINITION(54, http_oauth_unexpected_error, "http request failed with an unexpected error while retrieving a token") //! [Error Definitions] diff --git a/include/oauth_callback_fn.h b/include/oauth_callback_fn.h index 80fdaddaa..aabe1c3bd 100644 --- a/include/oauth_callback_fn.h +++ b/include/oauth_callback_fn.h @@ -4,9 +4,10 @@ #include #include #include +#include "trace_logger.h" namespace reinforcement_learning { using oauth_callback_t = - std::function&, std::string&, std::chrono::system_clock::time_point&)>; + std::function&, std::string&, std::chrono::system_clock::time_point&, i_trace* trace)>; } \ No newline at end of file diff --git a/rlclientlib/utility/api_header_token.h b/rlclientlib/utility/api_header_token.h index 82ff74491..8edc8d457 100644 --- a/rlclientlib/utility/api_header_token.h +++ b/rlclientlib/utility/api_header_token.h @@ -107,7 +107,7 @@ class api_header_token_callback { using namespace std::chrono; system_clock::time_point tp; - RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry)); + RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry, trace)); if (_bearer_token.empty()) { From ebfcfe772e2abf5199db4d3f7b90da5454be1620 Mon Sep 17 00:00:00 2001 From: James Longo Date: Fri, 9 Aug 2024 14:43:58 -0400 Subject: [PATCH 04/19] removed trace from lock scope --- include/azure_credentials_provider.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h index 676889e83..81783f634 100644 --- a/include/azure_credentials_provider.h +++ b/include/azure_credentials_provider.h @@ -47,8 +47,8 @@ class azure_credentials_provider { std::lock_guard lock(_creds_mtx); auth = _creds->GetToken(request_context, context); - TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token"); } + TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token"); token_out = auth.Token; // Casting from an azure DateTime object to a time_point does the calculation From 14303e81d94c43853481607d562944e0ee02d435 Mon Sep 17 00:00:00 2001 From: James Longo Date: Fri, 9 Aug 2024 16:02:30 -0400 Subject: [PATCH 05/19] added set(CMAKE_DEBUG_POSTFIX "") to ensure VS builds have the same lib debug suffixes as the Ninja builds. update the rlnetnative import constant (removing the d suffix) --- CMakeLists.txt | 5 +++++ bindings/cs/rl.net/Native/NativeImports.cs | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a0aa41108..a12cc277a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,11 @@ if(POLICY CMP0091) cmake_policy(SET CMP0091 NEW) endif() +# ensure all of the build tools generate the same output on all platforms +# note: this change was made since building with Ninja does not add suffixes +# but, using the VS generator does. +set(CMAKE_DEBUG_POSTFIX "") + if(WIN32) # Due to needing to configure the CMAKE platform, this needs to be included before the # top-level project() declaration. diff --git a/bindings/cs/rl.net/Native/NativeImports.cs b/bindings/cs/rl.net/Native/NativeImports.cs index 648dc0fdc..d59656bdf 100644 --- a/bindings/cs/rl.net/Native/NativeImports.cs +++ b/bindings/cs/rl.net/Native/NativeImports.cs @@ -1,7 +1,9 @@ namespace Rl.Net.Native { internal static class NativeImports { + // NOTE: RLNETNATIVE for debug and release are the same, + // but this is a placeholder for future changes. #if DEBUG - internal const string RLNETNATIVE = "rlnetnatived"; + internal const string RLNETNATIVE = "rlnetnative"; #else internal const string RLNETNATIVE = "rlnetnative"; #endif From fc2fc17be21e56dfaa5b072294dd70cf5409ccc8 Mon Sep 17 00:00:00 2001 From: James Longo Date: Fri, 9 Aug 2024 16:46:14 -0400 Subject: [PATCH 06/19] updated formatting (according to tidy) --- .../rl.net.native/rl.net.azure_factories.cc | 28 +++------- .../cs/rl.net.native/rl.net.azure_factories.h | 10 ++-- examples/rl_sim_cpp/rl_sim.cc | 7 +-- examples/rl_sim_cpp/rl_sim.h | 2 +- include/azure_credentials_provider.h | 56 +++++++++---------- include/errors_data.h | 3 +- include/oauth_callback_fn.h | 7 ++- 7 files changed, 50 insertions(+), 63 deletions(-) diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.cc b/bindings/cs/rl.net.native/rl.net.azure_factories.cc index dcd7aa79f..20ddacc11 100644 --- a/bindings/cs/rl.net.native/rl.net.azure_factories.cc +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.cc @@ -5,27 +5,20 @@ namespace rl_net_native rl_net_native::azure_factory_oauth_callback_t g_oauth_callback = nullptr; rl_net_native::azure_factory_oauth_callback_complete_t g_oauth_callback_complete = nullptr; -static int azure_factory_oauth_callback(const std::vector& scopes, - std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry, - reinforcement_learning::i_trace *trace) +static int azure_factory_oauth_callback(const std::vector& scopes, std::string& oauth_token, + std::chrono::system_clock::time_point& token_expiry, reinforcement_learning::i_trace* trace) { - if (g_oauth_callback == nullptr) - { - return reinforcement_learning::error_code::invalid_argument; - } + if (g_oauth_callback == nullptr) { return reinforcement_learning::error_code::invalid_argument; } // create a null terminated array of scope string pointers. // these are pointers are readonly and owned by the caller. // since we create the vector of n elements, we can guarantee // the last element is null. std::vector native_scopes(scopes.size() + 1); - for (int i = 0; i < scopes.size(); ++i) - { - native_scopes[i] = scopes[i].c_str(); - } + for (int i = 0; i < scopes.size(); ++i) { native_scopes[i] = scopes[i].c_str(); } // we expect to get a pointer to a null terminated string. // it's allocated on the managed heap, so we can't free it here // instead, we will pass it back to the managed code after we copy it. - char *oauth_token_ptr = nullptr; + char* oauth_token_ptr = nullptr; std::int64_t expiryUnixTime = 0; auto ret = g_oauth_callback(native_scopes.data(), &oauth_token_ptr, &expiryUnixTime); if (ret == reinforcement_learning::error_code::success) @@ -34,23 +27,20 @@ static int azure_factory_oauth_callback(const std::vector& scopes, oauth_token = oauth_token_ptr; token_expiry = std::chrono::system_clock::from_time_t(expiryUnixTime); } - else - { - TRACE_ERROR(trace, "rl_net_native::azure_factory_oauth_callback: failed to retrieve token"); - } + else { TRACE_ERROR(trace, "rl_net_native::azure_factory_oauth_callback: failed to retrieve token"); } g_oauth_callback_complete(oauth_token_ptr, reinforcement_learning::error_code::success); return ret; } -} +} // namespace rl_net_native API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback, - rl_net_native::azure_factory_oauth_callback_complete_t completion) + rl_net_native::azure_factory_oauth_callback_complete_t completion) { if (rl_net_native::g_oauth_callback == nullptr) { rl_net_native::g_oauth_callback = callback; reinforcement_learning::register_default_factories_callback( - reinforcement_learning::oauth_callback_t { rl_net_native::azure_factory_oauth_callback }); + reinforcement_learning::oauth_callback_t{rl_net_native::azure_factory_oauth_callback}); rl_net_native::g_oauth_callback_complete = completion; } } diff --git a/bindings/cs/rl.net.native/rl.net.azure_factories.h b/bindings/cs/rl.net.native/rl.net.azure_factories.h index 4812cbc4c..10ba324e5 100644 --- a/bindings/cs/rl.net.native/rl.net.azure_factories.h +++ b/bindings/cs/rl.net.native/rl.net.azure_factories.h @@ -1,7 +1,7 @@ #pragma once -#include "rl.net.native.h" #include "factory_resolver.h" +#include "rl.net.native.h" #include @@ -28,8 +28,8 @@ extern "C" // both the callback and completion functions must be registered before any calls to the Azure factories are made // typically, this function should be called once during the initialization of the application // callback - the callback function to be called by the C++ code to get the OAuth token - // completion - the callback function to be called by the C++ code to signal the completion of the OAuth token request - API void RegisterDefaultFactoriesCallback( - rl_net_native::azure_factory_oauth_callback_t callback, - rl_net_native::azure_factory_oauth_callback_complete_t completion); + // completion - the callback function to be called by the C++ code to signal the completion of the OAuth token + // request + API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback, + rl_net_native::azure_factory_oauth_callback_complete_t completion); } diff --git a/examples/rl_sim_cpp/rl_sim.cc b/examples/rl_sim_cpp/rl_sim.cc index fdd900458..21c47c2f6 100644 --- a/examples/rl_sim_cpp/rl_sim.cc +++ b/examples/rl_sim_cpp/rl_sim.cc @@ -496,7 +496,7 @@ int rl_sim::init_rl() // Note: This requires C++14 or better using namespace std::placeholders; reinforcement_learning::oauth_callback_t callback = - std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4); + std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4); reinforcement_learning::register_default_factories_callback(callback); #endif } @@ -650,11 +650,8 @@ std::string rl_sim::create_event_id() return oss.str(); } -rl_sim::rl_sim(boost::program_options::variables_map vm) - : _options(std::move(vm)) - , _loop_kind(CB) +rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB) { - if (_options["ccb"].as()) { _loop_kind = CCB; } else if (_options["slates"].as()) { _loop_kind = Slates; } else if (_options["ca"].as()) { _loop_kind = CA; } diff --git a/examples/rl_sim_cpp/rl_sim.h b/examples/rl_sim_cpp/rl_sim.h index 40699bcac..15e63192a 100644 --- a/examples/rl_sim_cpp/rl_sim.h +++ b/examples/rl_sim_cpp/rl_sim.h @@ -9,7 +9,7 @@ #include "azure_credentials_provider.h" #ifdef LINK_AZURE_LIBS -#include +# include #endif #include "live_model.h" diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h index 81783f634..a273e85e7 100644 --- a/include/azure_credentials_provider.h +++ b/include/azure_credentials_provider.h @@ -2,38 +2,36 @@ #ifdef LINK_AZURE_LIBS -#include -#include +# include +# include // These are needed because azure does a bad time conversion -#include -#include -#include -#include -#include +# include +# include +# include +# include +# include +# include -#include +# ifdef LINK_AZURE_LIBS +# include +# endif -#ifdef LINK_AZURE_LIBS -#include -#endif - -#include "err_constants.h" -#include "trace_logger.h" +# include "err_constants.h" +# include "trace_logger.h" namespace reinforcement_learning { -template +template class azure_credentials_provider { public: - template - azure_credentials_provider(Args&&... args) : - _creds(std::make_unique(std::forward(args)...)) + template + azure_credentials_provider(Args&&... args) : _creds(std::make_unique(std::forward(args)...)) { } int get_credentials(const std::vector& scopes, std::string& token_out, - std::chrono::system_clock::time_point& expiry_out, i_trace* trace) + std::chrono::system_clock::time_point& expiry_out, i_trace* trace) { using namespace Azure::Core; using namespace Azure::Core::Credentials; @@ -63,20 +61,20 @@ class azure_credentials_provider expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); } catch (AuthenticationException& e) - { + { TRACE_ERROR(trace, e.what()); return error_code::http_oauth_authentication_error; - } - catch (std::exception& e) - { + } + catch (std::exception& e) + { TRACE_ERROR(trace, e.what()); - return error_code::http_oauth_unexpected_error; - } - catch (...) - { + return error_code::http_oauth_unexpected_error; + } + catch (...) + { TRACE_ERROR(trace, "azure_credentials_provider: an unexpected unknown error occurred"); - return error_code::http_oauth_unexpected_error; - } + return error_code::http_oauth_unexpected_error; + } return error_code::success; } diff --git a/include/errors_data.h b/include/errors_data.h index 6d2c15e43..7747b0e5e 100644 --- a/include/errors_data.h +++ b/include/errors_data.h @@ -56,5 +56,6 @@ ERROR_CODE_DEFINITION(50, http_api_key_not_provided, "Http api key must be provi ERROR_CODE_DEFINITION(51, http_model_uri_not_provided, "Model Blob URI parameter was not passed in via configuration") ERROR_CODE_DEFINITION(52, static_model_load_error, "Static model passed in C# layer is not loading properly") ERROR_CODE_DEFINITION(53, http_oauth_authentication_error, "http request failed to authenticate") -ERROR_CODE_DEFINITION(54, http_oauth_unexpected_error, "http request failed with an unexpected error while retrieving a token") +ERROR_CODE_DEFINITION( + 54, http_oauth_unexpected_error, "http request failed with an unexpected error while retrieving a token") //! [Error Definitions] diff --git a/include/oauth_callback_fn.h b/include/oauth_callback_fn.h index aabe1c3bd..bb8d4fee3 100644 --- a/include/oauth_callback_fn.h +++ b/include/oauth_callback_fn.h @@ -1,13 +1,14 @@ #pragma once +#include "trace_logger.h" + #include #include #include #include -#include "trace_logger.h" namespace reinforcement_learning { -using oauth_callback_t = - std::function&, std::string&, std::chrono::system_clock::time_point&, i_trace* trace)>; +using oauth_callback_t = std::function&, std::string&, std::chrono::system_clock::time_point&, i_trace* trace)>; } \ No newline at end of file From 060ce8a5b6a513f76ff0b5e0353afc03f116e6ab Mon Sep 17 00:00:00 2001 From: James Longo Date: Mon, 12 Aug 2024 09:28:51 -0400 Subject: [PATCH 07/19] moved find dotnet-t4 to FindDotnet.cmake and check it for windows only. the build explicitly installs dotnet-t4 for windows. --- bindings/cs/rl.net/CMakeLists.txt | 8 -------- cmake/Modules/FindDotnet.cmake | 6 +++++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/bindings/cs/rl.net/CMakeLists.txt b/bindings/cs/rl.net/CMakeLists.txt index e7258c6ee..ff3252ea0 100644 --- a/bindings/cs/rl.net/CMakeLists.txt +++ b/bindings/cs/rl.net/CMakeLists.txt @@ -1,11 +1,3 @@ -find_program(DOTNET_T4_EXECUTABLE NAMES t4) - -if (DOTNET_T4_EXECUTABLE) - message(STATUS "found dotnet-t4: ${DOTNET_T4_EXECUTABLE}") -else() - message(FATAL_ERROR "dotnet-t4 tool not found. install dotnet-t4 using: dotnet tool install -g dotnet-t4") -endif() - set(RL_NET_SOURCES Native/ErrorCallback.cs Native/GCHandleLifetime.cs diff --git a/cmake/Modules/FindDotnet.cmake b/cmake/Modules/FindDotnet.cmake index 846ec23fd..6deb3c8e5 100644 --- a/cmake/Modules/FindDotnet.cmake +++ b/cmake/Modules/FindDotnet.cmake @@ -1 +1,5 @@ -find_program(DOTNET_COMMAND "dotnet" REQUIRED) \ No newline at end of file +find_program(DOTNET_COMMAND "dotnet" REQUIRED) + +if(WIN32) + find_program(DOTNET_T4_COMMAND "t4" REQUIRED) +endif() From ac470e615c61f20aec808f4652e28c7662f12d2d Mon Sep 17 00:00:00 2001 From: James Longo Date: Mon, 12 Aug 2024 10:07:19 -0400 Subject: [PATCH 08/19] the x64 macos image was changed to macos-latest-large. see (https://github.com/actions/runner-images?tab=readme-ov-file#available-images) --- .github/workflows/asan.yml | 4 ++-- .github/workflows/build_rlclientlib.yml | 2 +- .github/workflows/build_vw_bp.yml | 2 +- .github/workflows/dotnet_nugets.yml | 4 ++-- .github/workflows/vcpkg_build.yml | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 17142d9f9..02e680065 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -18,8 +18,8 @@ jobs: strategy: fail-fast: false matrix: - #os: [windows-latest, ubuntu-latest, macos-latest] - os: [ubuntu-latest, macos-latest] # Temporarily remove windows asan + #os: [windows-latest, ubuntu-latest, macos-latest-large] + os: [ubuntu-latest, macos-latest-large] # Temporarily remove windows asan preset: [vcpkg-asan-debug, vcpkg-ubsan-debug] exclude: # UBSan not supported by MSVC on Windows diff --git a/.github/workflows/build_rlclientlib.yml b/.github/workflows/build_rlclientlib.yml index 3300a5ff2..17dddbba6 100644 --- a/.github/workflows/build_rlclientlib.yml +++ b/.github/workflows/build_rlclientlib.yml @@ -75,7 +75,7 @@ jobs: build-macos: # Mac build doesn't have any additional features enabled name: rlclientlib-${{ matrix.build_type }}-macos-latest - runs-on: macos-latest + runs-on: macos-latest-large strategy: fail-fast: false matrix: diff --git a/.github/workflows/build_vw_bp.yml b/.github/workflows/build_vw_bp.yml index 21c3e37e7..21e136189 100644 --- a/.github/workflows/build_vw_bp.yml +++ b/.github/workflows/build_vw_bp.yml @@ -26,7 +26,7 @@ jobs: config: - { os: "windows-latest", vcpkg_target_triplet: "x64-windows-static" } - { os: "ubuntu-latest", vcpkg_target_triplet: "x64-linux" } - - { os: "macos-latest", vcpkg_target_triplet: "x64-osx" } + - { os: "macos-latest-large", vcpkg_target_triplet: "x64-osx" } build: # Set the appropriate static runtime for MSVC on Windows # CMake ignores the CMAKE_MSVC_RUNTIME_LIBRARY option on other platforms diff --git a/.github/workflows/dotnet_nugets.yml b/.github/workflows/dotnet_nugets.yml index a383cda5a..2d5a2bb24 100644 --- a/.github/workflows/dotnet_nugets.yml +++ b/.github/workflows/dotnet_nugets.yml @@ -29,7 +29,7 @@ jobs: config: - { os: "windows-latest", runtime_id: "win-x64", vcpkg_target_triplet: "x64-windows-static" } - { os: "ubuntu-latest", runtime_id: "linux-x64", vcpkg_target_triplet: "x64-linux" } - - { os: "macos-latest", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" } + - { os: "macos-latest-large", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" } runs-on: ${{matrix.config.os}} steps: - uses: actions/checkout@v2 @@ -162,7 +162,7 @@ jobs: config: - { os: "windows-latest", runtime_id: "win-x64" } - { os: "ubuntu-latest", runtime_id: "linux-x64" } - - { os: "macos-latest", runtime_id: "osx-x64" } + - { os: "macos-latest-large", runtime_id: "osx-x64" } runs-on: ${{matrix.config.os}} steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index 4fb2fc27a..8f0664337 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest-large, windows-latest] preset: [vcpkg-debug, vcpkg-release] steps: - uses: actions/checkout@v3 From 56ff878122a5367e32fd5924c35c1b96f2b634ec Mon Sep 17 00:00:00 2001 From: James Longo Date: Mon, 12 Aug 2024 10:37:33 -0400 Subject: [PATCH 09/19] for now, moving to build runner macos-13 supporting x64 --- .github/workflows/asan.yml | 4 ++-- .github/workflows/build_rlclientlib.yml | 2 +- .github/workflows/build_vw_bp.yml | 2 +- .github/workflows/dotnet_nugets.yml | 4 ++-- .github/workflows/vcpkg_build.yml | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 02e680065..8a291819f 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -18,8 +18,8 @@ jobs: strategy: fail-fast: false matrix: - #os: [windows-latest, ubuntu-latest, macos-latest-large] - os: [ubuntu-latest, macos-latest-large] # Temporarily remove windows asan + #os: [windows-latest, ubuntu-latest, macos-13] + os: [ubuntu-latest, macos-13] # Temporarily remove windows asan preset: [vcpkg-asan-debug, vcpkg-ubsan-debug] exclude: # UBSan not supported by MSVC on Windows diff --git a/.github/workflows/build_rlclientlib.yml b/.github/workflows/build_rlclientlib.yml index 17dddbba6..20aeefeee 100644 --- a/.github/workflows/build_rlclientlib.yml +++ b/.github/workflows/build_rlclientlib.yml @@ -75,7 +75,7 @@ jobs: build-macos: # Mac build doesn't have any additional features enabled name: rlclientlib-${{ matrix.build_type }}-macos-latest - runs-on: macos-latest-large + runs-on: macos-13 strategy: fail-fast: false matrix: diff --git a/.github/workflows/build_vw_bp.yml b/.github/workflows/build_vw_bp.yml index 21e136189..53c625eab 100644 --- a/.github/workflows/build_vw_bp.yml +++ b/.github/workflows/build_vw_bp.yml @@ -26,7 +26,7 @@ jobs: config: - { os: "windows-latest", vcpkg_target_triplet: "x64-windows-static" } - { os: "ubuntu-latest", vcpkg_target_triplet: "x64-linux" } - - { os: "macos-latest-large", vcpkg_target_triplet: "x64-osx" } + - { os: "macos-13", vcpkg_target_triplet: "x64-osx" } build: # Set the appropriate static runtime for MSVC on Windows # CMake ignores the CMAKE_MSVC_RUNTIME_LIBRARY option on other platforms diff --git a/.github/workflows/dotnet_nugets.yml b/.github/workflows/dotnet_nugets.yml index 2d5a2bb24..44116260b 100644 --- a/.github/workflows/dotnet_nugets.yml +++ b/.github/workflows/dotnet_nugets.yml @@ -29,7 +29,7 @@ jobs: config: - { os: "windows-latest", runtime_id: "win-x64", vcpkg_target_triplet: "x64-windows-static" } - { os: "ubuntu-latest", runtime_id: "linux-x64", vcpkg_target_triplet: "x64-linux" } - - { os: "macos-latest-large", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" } + - { os: "macos-13", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" } runs-on: ${{matrix.config.os}} steps: - uses: actions/checkout@v2 @@ -162,7 +162,7 @@ jobs: config: - { os: "windows-latest", runtime_id: "win-x64" } - { os: "ubuntu-latest", runtime_id: "linux-x64" } - - { os: "macos-latest-large", runtime_id: "osx-x64" } + - { os: "macos-13", runtime_id: "osx-x64" } runs-on: ${{matrix.config.os}} steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index 8f0664337..7d2b456fd 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest-large, windows-latest] + os: [ubuntu-latest, macos-13, windows-latest] preset: [vcpkg-debug, vcpkg-release] steps: - uses: actions/checkout@v3 From f05a552b17bf2edc99993055185c1ce11230fd89 Mon Sep 17 00:00:00 2001 From: James Longo Date: Wed, 14 Aug 2024 15:29:18 -0400 Subject: [PATCH 10/19] updates to provide default azure credential implementations via the factory resolver cmake changes to link with azure identity when RL_LINK_AZURE_LIBS is defined --- CMakeLists.txt | 5 - bindings/cs/CMakeLists.txt | 7 + examples/basic_usage_cpp/CMakeLists.txt | 6 + examples/override_interface/CMakeLists.txt | 6 + examples/rl_sim_cpp/main.cc | 3 +- examples/rl_sim_cpp/rl_sim.cc | 18 +- examples/test_cpp/CMakeLists.txt | 6 + include/azure_credentials_provider.h | 114 +++++++--- include/constants.h | 12 ++ include/factory_resolver.h | 92 +++++++- include/oauth_callback_fn.h | 39 +++- rlclientlib/CMakeLists.txt | 6 + rlclientlib/azure_factories.cc | 204 ++++++++++++++++-- rlclientlib/azure_factories.h | 5 +- rlclientlib/factory_resolver.cc | 11 +- .../restapi_data_transport_oauth.cc | 12 +- .../model_mgmt/restapi_data_transport_oauth.h | 6 +- rlclientlib/utility/api_header_token.h | 10 +- test_tools/example_gen/CMakeLists.txt | 5 + test_tools/sender_test/CMakeLists.txt | 6 + unit_test/CMakeLists.txt | 6 + 21 files changed, 492 insertions(+), 87 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a12cc277a..a0aa41108 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,11 +6,6 @@ if(POLICY CMP0091) cmake_policy(SET CMP0091 NEW) endif() -# ensure all of the build tools generate the same output on all platforms -# note: this change was made since building with Ninja does not add suffixes -# but, using the VS generator does. -set(CMAKE_DEBUG_POSTFIX "") - if(WIN32) # Due to needing to configure the CMAKE platform, this needs to be included before the # top-level project() declaration. diff --git a/bindings/cs/CMakeLists.txt b/bindings/cs/CMakeLists.txt index 70bf843ca..333e31288 100644 --- a/bindings/cs/CMakeLists.txt +++ b/bindings/cs/CMakeLists.txt @@ -1,6 +1,13 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/") include(FindDotnet) +# note: this change was made since building with Ninja does not add suffixes +# but, using the VS generator does. rl.net uses dllimport to load rlnetnative. +# This is a workaround to make sure the correct dll is used. +if (WIN32 AND CMAKE_GENERATOR MATCHES "Visual Studio") + set(CMAKE_DEBUG_POSTFIX "") +endif() + add_subdirectory(rl.net.native) add_subdirectory(rl.net) add_subdirectory(rl.net.cli) diff --git a/examples/basic_usage_cpp/CMakeLists.txt b/examples/basic_usage_cpp/CMakeLists.txt index 1edde3b94..bd1b5fa00 100644 --- a/examples/basic_usage_cpp/CMakeLists.txt +++ b/examples/basic_usage_cpp/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(basic_usage_cpp.out ) target_link_libraries(basic_usage_cpp.out PRIVATE rlclientlib) + +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(basic_usage_cpp.out PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(basic_usage_cpp.out PRIVATE Azure::azure-identity) +endif() diff --git a/examples/override_interface/CMakeLists.txt b/examples/override_interface/CMakeLists.txt index 2f4c28c07..73264c495 100644 --- a/examples/override_interface/CMakeLists.txt +++ b/examples/override_interface/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(override_interface.out ) target_link_libraries(override_interface.out PRIVATE rlclientlib) + +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(override_interface.out PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(override_interface.out PRIVATE Azure::azure-identity) +endif() diff --git a/examples/rl_sim_cpp/main.cc b/examples/rl_sim_cpp/main.cc index 324d6051e..5852fbaef 100644 --- a/examples/rl_sim_cpp/main.cc +++ b/examples/rl_sim_cpp/main.cc @@ -38,8 +38,7 @@ po::variables_map process_cmd_line(const int argc, char** argv) "random_seed", po::value()->default_value(rand()), "Random seed. Default is random")( "delay", po::value()->default_value(2000), "Delay between events in ms")( "quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value()->default_value(true), - "Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")( - "azure_oauth_factories", po::value()->default_value(false), "Use oauth for azure factores. Default false"); + "Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats"); po::variables_map vm; store(parse_command_line(argc, argv, desc), vm); diff --git a/examples/rl_sim_cpp/rl_sim.cc b/examples/rl_sim_cpp/rl_sim.cc index 21c47c2f6..24896a327 100644 --- a/examples/rl_sim_cpp/rl_sim.cc +++ b/examples/rl_sim_cpp/rl_sim.cc @@ -490,16 +490,18 @@ int rl_sim::init_rl() sender_factory = &factory; } // probably incompatible with the throughput option? - else if (_options["azure_oauth_factories"].as()) - { #ifdef LINK_AZURE_LIBS - // Note: This requires C++14 or better - using namespace std::placeholders; - reinforcement_learning::oauth_callback_t callback = - std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4); - reinforcement_learning::register_default_factories_callback(callback); + // Note: The azure_oauth_factories switch has been removed since registering the factories by default does not + // directly utilize the callback. The factory implementations need to be configured in the configuration file + // for them to function (see factory_resolver.h). + // Note: If USE_AZURE_FACTORIES is defined, the default factory implementations will be different (see + // factory_resolver.h). + // Note: This requires C++14 or better since the Azure libraries require C++14 or better. + using namespace std::placeholders; + reinforcement_learning::oauth_callback_t callback = + std::bind(&azure_credentials_provider_t::get_token, &_creds, _1, _2, _3, _4); + reinforcement_learning::register_default_factories_callback(callback); #endif - } // Initialize the API _rl = std::unique_ptr(new r::live_model(config, _on_error, this, diff --git a/examples/test_cpp/CMakeLists.txt b/examples/test_cpp/CMakeLists.txt index ce5bec8e3..5654b937c 100644 --- a/examples/test_cpp/CMakeLists.txt +++ b/examples/test_cpp/CMakeLists.txt @@ -10,3 +10,9 @@ add_executable(rl_test.out target_include_directories(rl_test.out PRIVATE $) target_link_libraries(rl_test.out PRIVATE Boost::program_options rlclientlib) + +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(rl_test.out PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(rl_test.out PRIVATE Azure::azure-identity) +endif() diff --git a/include/azure_credentials_provider.h b/include/azure_credentials_provider.h index a273e85e7..f7f440f06 100644 --- a/include/azure_credentials_provider.h +++ b/include/azure_credentials_provider.h @@ -1,37 +1,87 @@ #pragma once #ifdef LINK_AZURE_LIBS +# include "err_constants.h" +# include "oauth_callback_fn.h" -# include -# include -// These are needed because azure does a bad time conversion -# include +# include # include -# include -# include # include -# include +# include -# ifdef LINK_AZURE_LIBS -# include -# endif +// These are needed because azure does a bad time conversion +# include -# include "err_constants.h" -# include "trace_logger.h" +# include namespace reinforcement_learning { +namespace +{ +/** + * @brief Get the GMT offset from the local time. + */ +inline std::chrono::system_clock::duration get_gmt_offset() +{ + auto get_time_point = [](std::tm& tm) + { + // set the tm_isdst field to -1 to let mktime determine if DST is in effect + tm.tm_isdst = -1; + return std::chrono::system_clock::from_time_t(std::mktime(&tm)); + }; + std::time_t now = std::time(nullptr); + std::tm local_tm{}; + localtime_s(&local_tm, &now); + std::tm gmt_tm{}; + gmtime_s(&gmt_tm, &now); + return get_time_point(local_tm) - get_time_point(gmt_tm); +} +} // namespace + +/** + * @brief A template class that provides Azure OAuth credentials. + * + * This class is a template that requires a type T which must be a subclass of + * Azure::Core::Credentials::TokenCredential. It implements the i_oauth_credentials_provider + * interface to provide OAuth tokens for Azure services. + * + * @tparam T The type of the Azure credential, must be a subclass of Azure::Core::Credentials::TokenCredential. + */ template -class azure_credentials_provider +class azure_credentials_provider : public i_oauth_credentials_provider { + static_assert(std::is_base_of::value, + "T must be a subclass of Azure::Core::Credentials::TokenCredential"); + public: + /** + * @brief Constructs an azure_credentials_provider with the given arguments. + * + * @tparam Args The types of the arguments to forward to the constructor of T. + * @param args The arguments to forward to the constructor of T. + */ template azure_credentials_provider(Args&&... args) : _creds(std::make_unique(std::forward(args)...)) { + _gmt_offset = get_gmt_offset(); } - int get_credentials(const std::vector& scopes, std::string& token_out, - std::chrono::system_clock::time_point& expiry_out, i_trace* trace) + /** + * @brief Default destructor. + */ + ~azure_credentials_provider() override = default; + + /** + * @brief Retrieves an OAuth token for the given scopes. + * + * @param scopes The scopes for which the token is requested. + * @param token_out The output parameter where the token will be stored. + * @param expiry_out The output parameter where the token expiry time will be stored. + * @param trace The trace object for logging. + * @return int The error code indicating success or failure. + */ + int get_token(const std::vector& scopes, std::string& token_out, + std::chrono::system_clock::time_point& expiry_out, i_trace* trace) override { using namespace Azure::Core; using namespace Azure::Core::Credentials; @@ -48,17 +98,7 @@ class azure_credentials_provider } TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token"); token_out = auth.Token; - - // Casting from an azure DateTime object to a time_point does the calculation - // incorrectly. The expiration is returned as a local time, but the library - // assumes that it is GMT, and converts the value incorrectly. - // See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075 - // expiry_out = static_cast(auth.ExpiresOn); - std::string dt_string = auth.ExpiresOn.ToString(); - std::tm tm = {}; - std::istringstream ss(dt_string); - ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ"); - expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm)); + expiry_out = get_expiry_time(auth); } catch (AuthenticationException& e) { @@ -78,10 +118,28 @@ class azure_credentials_provider return error_code::success; } +private: + /** + * @brief Gets the adjusted expiry time of the given access token. + * + * @param access_token The access token whose expiry time is to be calculated. + * @return std::chrono::system_clock::time_point The calculated expiry time. + * @remarks This function is needed because Azure library returns local time + * instead of GMT time. + */ + std::chrono::system_clock::time_point get_expiry_time(const Azure::Core::Credentials::AccessToken& access_token) + { + // Casting from an azure DateTime object to a time_point does the calculation + // incorrectly. The expiration is returned as a local time, but the library + // assumes that it is GMT, and converts the value incorrectly. + // See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075 + return static_cast(access_token.ExpiresOn) - _gmt_offset; + } + private: std::unique_ptr _creds; mutable std::mutex _creds_mtx; + std::chrono::system_clock::duration _gmt_offset; }; } // namespace reinforcement_learning - -#endif +#endif // LINK_AZURE_LIBS diff --git a/include/constants.h b/include/constants.h index c4f3c632e..c7ee7e528 100644 --- a/include/constants.h +++ b/include/constants.h @@ -27,6 +27,10 @@ const char* const PROTOCOL_VERSION = "protocol.version"; const char* const HTTP_API_KEY = "http.api.key"; const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name"; const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type"; +const char* const AZURE_OAUTH_CREDENTIAL_TYPE = "azure.oauth.credential.type"; +const char* const AZURE_OAUTH_CREDENTIAL_CLIENTID = "azure.oauth.credential.clientid"; +const char* const AZURE_OAUTH_CREDENTIAL_TENANTID = "azure.oauth.credential.tenantid"; +const char* const AZURE_OAUTH_CREDENTIAL_CLIENTSECRET = "azure.oauth.credential.clientsecret"; const char* const AUDIT_ENABLED = "audit.enabled"; const char* const AUDIT_OUTPUT_PATH = "audit.output.path"; @@ -120,6 +124,7 @@ const char* const NO_MODEL_DATA = "NO_MODEL_DATA"; const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA"; const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA"; const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH"; +const char* const HTTP_MODEL_DATA_OAUTH_AZ = "HTTP_MODEL_DATA_OAUTH_AZ"; const char* const VW = "VW"; const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF"; const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER"; @@ -132,8 +137,11 @@ const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER"; const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER"; const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER"; const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH"; +const char* const EPISODE_HTTP_API_SENDER_OAUTH_AZ = "EPISODE_HTTP_API_SENDER_OAUTH_AZ"; const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH"; +const char* const OBSERVATION_HTTP_API_SENDER_OAUTH_AZ = "OBSERVATION_HTTP_API_SENDER_OAUTH_AZ"; const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH"; +const char* const INTERACTION_HTTP_API_SENDER_OAUTH_AZ = "INTERACTION_HTTP_API_SENDER_OAUTH_AZ"; const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER"; const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER"; const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER"; @@ -146,6 +154,10 @@ const char* const CONTENT_ENCODING_DEDUP = "DEDUP"; const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key"; const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer"; const char* const TRACE_LOG_LEVEL_DEFAULT = "info"; +const char* const AZURE_OAUTH_CREDENTIALS_DEFAULT = "DEFAULT_CREDENTIAL"; +const char* const AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY = "MANAGED_IDENTITY"; +const char* const AZURE_OAUTH_CREDENTIALS_AZURECLI = "AZURECLI"; +const char* const AZURE_OAUTH_CREDENTIALS_CLIENTSECRET = "CLIENT_SECRET"; const char* const QUEUE_MODE_DROP = "DROP"; const char* const QUEUE_MODE_BLOCK = "BLOCK"; diff --git a/include/factory_resolver.h b/include/factory_resolver.h index a4cae3316..884d90b8b 100644 --- a/include/factory_resolver.h +++ b/include/factory_resolver.h @@ -1,4 +1,5 @@ #pragma once +#include "azure_credentials_provider.h" #include "oauth_callback_fn.h" #include "object_factory.h" @@ -28,6 +29,60 @@ class error_callback_fn; * provide the mechanism used when logging internal events in the API implementation. */ using trace_logger_factory_t = utility::object_factory; +/** + * @brief Factory to create Azure credential providers used to retrieve Azure OAuth tokens. + * Advanced extension point: Register another implementation of i_oauth_credentials_provider to + * provide the mechanism used when providing Azure OAuth token to the API implementation. + * + * @remark azure_cred_provider_factory_t provides the necessary Azure credential implementation used + * by the library to retrieve Azure OAuth tokens. The following Azure credential providers are + * provided by the library: + * + * - azure_credentials_provider - Provides Azure OAuth token using client + * secret. + * - azure_credentials_provider - Provides Azure OAuth token using managed identity. + * - azure_credentials_provider - Provides Azure OAuth token using Azure CLI. + * - azure_credentials_provider - Provides Azure OAuth token using default + * credential. + * + * The library provides a default implementation of azure_credentials_provider. + * + * The following configuration keys are used to configure the Azure credential provider: + * + * - "azure.oauth.credential.type" - The type of Azure credential provider to use. The following values are supported: + * - "CLIENT_SECRET" - Use azure_credentials_provider to provide Azure + * OAuth token. + * - "MANAGED_IDENTITY" - Use azure_credentials_provider to provide Azure OAuth + * token. + * - "AZURECLI" - Use azure_credentials_provider to provide Azure OAuth token. + * - "DEFAULT_CREDENTIAL" - Use azure_credentials_provider to provide Azure + * OAuth token. + * + * The following configuration keys are used to configure the Azure credential provider: + * + * - CLIENT_SECRET: + * - "azure.oauth.credential.clientid" + * - "azure.oauth.credential.tenantid" + * - "azure.oauth.credential.clientsecret" + * + * - MANAGED_IDENTITY: + * - "azure.oauth.credential.clientid" + * + * - AZURECLI: + * - "azure.oauth.credential.tenantid" + * + * To enable Azure OAuth token authentication, the following configuration keys may be set: + * + * "model.source": "HTTP_MODEL_DATA_OAUTH_AZ" + * "interaction.sender.implementation": "INTERACTION_HTTP_API_SENDER_OAUTH_AZ" + * "observation.sender.implementation": "OBSERVATION_HTTP_API_SENDER_OAUTH_AZ" + * "episode.sender.implementation": "EPISODE_HTTP_API_SENDER_OAUTH_AZ" + * + * For a custom implementation, register your implementations using the factory. Or, @see + * register_default_factories_callback if you don't want to use the factory. + */ +using azure_cred_provider_factory_t = + utility::object_factory; /** * @brief Factory to create model used in inference. * Advanced extension point: Register another implementation of i_model to @@ -60,6 +115,7 @@ extern model_factory_t& model_factory; extern sender_factory_t& sender_factory; extern trace_logger_factory_t& trace_logger_factory; extern time_provider_factory_t& time_provider_factory; +extern azure_cred_provider_factory_t& azure_cred_provider_factory; // For proper static intialization // Check https://en.wikibooks.org/wiki/More_C++_Idioms/Nifty_Counter for explanation @@ -75,10 +131,40 @@ struct factory_initializer // only one translation unit will initialize it static factory_initializer _init; -// no-op if USE_AZURE_FACTORIES is not defined /** - * @brief Register default factories with an authentication callback + * @brief Register default factories with an authentication callback. + * @remark This function can be called to register factories for + * retrieving auth tokens. This is useful when control over retrieving + * tokens is needed and can be provided by the application. + * + * To enable this feature, the application must call this function prior to + * calling API methods that retrieve sender/receiver objects that require it. + * The configuration file should define some or all of the following keys: + * + * "model.source": "HTTP_MODEL_DATA_OAUTH" + * "episode.sender.implementation": "EPISODE_HTTP_API_SENDER_OAUTH" + * "interaction.sender.implementation": "INTERACTION_HTTP_API_SENDER_OAUTH" + * "observation.sender.implementation": "OBSERVATION_HTTP_API_SENDER_OAUTH" + * + * NOTE: The defaults for the above keys are as follows: + * + * -- If USE_AZURE_FACTORIES is NOT defined as part of the build: + * + * "model.source": "NO_MODEL_DATA" + * "episode.sender.implementation": "EPISODE_FILE_SENDER" + * "interaction.sender.implementation": "INTERACTION_FILE_SENDER" + * "observation.sender.implementation": "OBSERVATION_FILE_SENDER" + * "time_provider.implementation": "CLOCK_TIME_PROVIDER" + * + * -- If USE_AZURE_FACTORIES is defined as part of the build: + * + * "model.source": "AZURE_STORAGE_BLOB" + * "episode.sender.implementation": "EPISODE_EH_SENDER" + * "interaction.sender.implementation": "INTERACTION_EH_SENDER" + * "observation.sender.implementation": "OBSERVATION_EH_SENDER" + * "time_provider.implementation": "NULL_TIME_PROVIDER" + * + * See the rl_sim example application to see how to use this feature. */ void register_default_factories_callback(oauth_callback_t& callback); - } // namespace reinforcement_learning diff --git a/include/oauth_callback_fn.h b/include/oauth_callback_fn.h index bb8d4fee3..a71906774 100644 --- a/include/oauth_callback_fn.h +++ b/include/oauth_callback_fn.h @@ -9,6 +9,43 @@ namespace reinforcement_learning { +class i_oauth_credentials_provider +{ +public: + virtual int get_token(const std::vector& scopes, std::string& token_out, + std::chrono::system_clock::time_point& expiry_out, i_trace* trace) = 0; + virtual ~i_oauth_credentials_provider() {} +}; + using oauth_callback_t = std::function&, std::string&, std::chrono::system_clock::time_point&, i_trace* trace)>; -} \ No newline at end of file + +/** + * @brief Wraps a callback function into an i_oauth_credentials_provider interface + * @remark azure_cred_provider_cb_wrapper was introduced as part of a set of changes + * that unify the way we handle Azure credentials by following the same pattern + * using the factory resolver. This class allows us to avoid breaking changes + * with respect to the existing codebase. + */ +class oauth_cred_provider_cb_wrapper : public i_oauth_credentials_provider +{ +public: + /** + * @brief Constructor + * @param cb The callback function to be wrapped + * @remark This wrapper keeps a reference to the callback function. Take note of the + * lifetime of the callback function to avoid disaster. + */ + oauth_cred_provider_cb_wrapper(oauth_callback_t& cb) : _cb(cb) {} + ~oauth_cred_provider_cb_wrapper() override = default; + + int get_token(const std::vector& scopes, std::string& token_out, + std::chrono::system_clock::time_point& expiry_out, i_trace* trace) override + { + return _cb(scopes, token_out, expiry_out, trace); + } + +private: + oauth_callback_t& _cb; +}; +} // namespace reinforcement_learning \ No newline at end of file diff --git a/rlclientlib/CMakeLists.txt b/rlclientlib/CMakeLists.txt index cb37705d8..7c3fabb7c 100644 --- a/rlclientlib/CMakeLists.txt +++ b/rlclientlib/CMakeLists.txt @@ -214,6 +214,12 @@ if(vw_USE_AZURE_FACTORIES) target_compile_definitions(rlclientlib PRIVATE USE_AZURE_FACTORIES) endif() +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(rlclientlib PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(rlclientlib PRIVATE Azure::azure-identity) +endif() + if(RL_USE_ZSTD) target_compile_definitions(rlclientlib PRIVATE USE_ZSTD) target_link_libraries(rlclientlib PRIVATE libzstd_static) diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 86a981b09..45e62dd42 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -2,6 +2,7 @@ #include "constants.h" #include "factory_resolver.h" +#include "azure_credentials_provider.h" #include "logger/event_logger.h" #include "logger/http_transport_client.h" #include "model_mgmt/restapi_data_transport.h" @@ -13,6 +14,13 @@ #include +#ifdef LINK_AZURE_LIBS +# include +# include +# include +# include +#endif + namespace reinforcement_learning { namespace m = model_management; @@ -37,15 +45,35 @@ int observation_api_sender_create(std::unique_ptr& retval, const u::co int interaction_api_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int oauth_restapi_data_transport_create(oauth_callback_t& callback, std::unique_ptr& retval, +int oauth_restapi_data_transport_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status); -int episode_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int episode_api_sender_oauth_cb_create(oauth_callback_t& callback, + std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int observation_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int observation_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int interaction_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int interaction_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int oauth_restapi_data_transport_create(std::unique_ptr& retval, + const u::configuration& config, i_trace* trace_logger, api_status* status); +int episode_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int observation_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int interaction_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); + +int azure_default_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status); +int azure_managed_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status); +int azure_azurecli_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status); +int azure_clientsecret_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status); + + void register_azure_factories() { data_transport_factory.register_type(value::AZURE_STORAGE_BLOB, restapi_data_transport_create); @@ -61,20 +89,33 @@ void register_azure_factories() sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER, observation_api_sender_create); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER, interaction_api_sender_create); sender_factory.register_type(value::EPISODE_HTTP_API_SENDER, episode_api_sender_create); + + // register default azure credentials provider factories + azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_DEFAULT, azure_default_cred_provider_create); + azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY, azure_managed_cred_provider_create); + azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_AZURECLI, azure_azurecli_cred_provider_create); + azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_CLIENTSECRET, azure_clientsecret_cred_provider_create); + + // register built-in azure oauth factories + data_transport_factory.register_type(value::HTTP_MODEL_DATA_OAUTH_AZ, oauth_restapi_data_transport_create); + sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER_OAUTH_AZ, observation_api_sender_oauth_create); + sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER_OAUTH_AZ, interaction_api_sender_oauth_create); + sender_factory.register_type(value::EPISODE_HTTP_API_SENDER_OAUTH_AZ, episode_api_sender_oauth_create); } void register_azure_oauth_factories(oauth_callback_t& callback) { + // user provided callback for oauth token // TODO: bind functions? using namespace std::placeholders; data_transport_factory.register_type( - value::HTTP_MODEL_DATA_OAUTH, std::bind(oauth_restapi_data_transport_create, callback, _1, _2, _3, _4)); + value::HTTP_MODEL_DATA_OAUTH, std::bind(oauth_restapi_data_transport_cb_create, callback, _1, _2, _3, _4)); sender_factory.register_type(value::OBSERVATION_HTTP_API_SENDER_OAUTH, - std::bind(observation_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); + std::bind(observation_api_sender_oauth_cb_create, callback, _1, _2, _3, _4, _5)); sender_factory.register_type(value::INTERACTION_HTTP_API_SENDER_OAUTH, - std::bind(interaction_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); - sender_factory.register_type( - value::EPISODE_HTTP_API_SENDER_OAUTH, std::bind(episode_api_sender_oauth_create, callback, _1, _2, _3, _4, _5)); + std::bind(interaction_api_sender_oauth_cb_create, callback, _1, _2, _3, _4, _5)); + sender_factory.register_type(value::EPISODE_HTTP_API_SENDER_OAUTH, + std::bind(episode_api_sender_oauth_cb_create, callback, _1, _2, _3, _4, _5)); } int restapi_data_transport_create(std::unique_ptr& retval, const u::configuration& config, @@ -107,6 +148,14 @@ std::string build_eh_url(const char* eh_host, const char* eh_name) return url; } +int get_azure_credential_provider(std::unique_ptr& retval, const u::configuration& config, + i_trace* trace_logger, api_status* status) +{ + const auto* const provider_type = + config.get(name::AZURE_OAUTH_CREDENTIAL_TYPE, value::AZURE_OAUTH_CREDENTIALS_DEFAULT); + return azure_cred_provider_factory.create(retval, provider_type, config, trace_logger, status); +} + int episode_sender_create(std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { @@ -133,16 +182,34 @@ int create_apim_http_api_sender(std::unique_ptr& retval, const u::conf return error_code::success; } -int create_apim_http_api_oauth_sender(oauth_callback_t& callback, std::unique_ptr& retval, +int create_apim_http_api_oauth_cb_sender(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, const char* api_host, int tasks_limit, int max_http_retries, std::chrono::milliseconds max_http_retry_duration, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { + // c++11 support (no make_unique) - azure libraries require c++14, but we might not be building with azure + auto cred_provider = azure_cred_provider_cb_wrapper_t(new oauth_cred_provider_cb_wrapper(callback)); + i_http_client* client = nullptr; + RETURN_IF_FAIL(create_http_client(api_host, cfg, &client, status)); + retval.reset(new http_transport_client>(client, tasks_limit, + max_http_retries, max_http_retry_duration, trace_logger, error_cb, std::move(cred_provider), + "https://eventhubs.azure.net//.default")); + return error_code::success; +} + +int create_apim_http_api_oauth_sender(std::unique_ptr& retval, + const u::configuration& cfg, const char* api_host, int tasks_limit, int max_http_retries, + std::chrono::milliseconds max_http_retry_duration, error_callback_fn* error_cb, i_trace* trace_logger, + api_status* status) +{ + std::unique_ptr cred_provider; + int ret_code = get_azure_credential_provider(cred_provider, cfg, trace_logger, status); + if (ret_code != error_code::success) { return ret_code; } i_http_client* client = nullptr; RETURN_IF_FAIL(create_http_client(api_host, cfg, &client, status)); retval.reset( new http_transport_client>(client, tasks_limit, max_http_retries, - max_http_retry_duration, trace_logger, error_cb, callback, "https://eventhubs.azure.net//.default")); + max_http_retry_duration, trace_logger, error_cb, std::move(cred_provider), "https://eventhubs.azure.net//.default")); return error_code::success; } @@ -211,46 +278,145 @@ int interaction_sender_create(std::unique_ptr& retval, const u::config return error_code::success; } -int oauth_restapi_data_transport_create(oauth_callback_t& callback, std::unique_ptr& retval, +int oauth_restapi_data_transport_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status) { + // c++11 support (no make_unique) - azure libraries require c++14, but we might not be building with azure + auto cred_provider = azure_cred_provider_cb_wrapper_t(new oauth_cred_provider_cb_wrapper(callback)); const auto* model_uri = config.get(name::MODEL_BLOB_URI, nullptr); if (model_uri == nullptr) { RETURN_ERROR(trace_logger, status, http_model_uri_not_provided); } i_http_client* client = nullptr; RETURN_IF_FAIL(create_http_client(model_uri, config, &client, status)); retval.reset(new m::restapi_data_transport_oauth(std::unique_ptr(client), config, - m::model_source::HTTP_API, trace_logger, callback, "https://storage.azure.com//.default")); + m::model_source::HTTP_API, trace_logger, std::move(cred_provider), "https://storage.azure.com//.default")); return error_code::success; } -int episode_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int episode_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::EPISODE_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + return create_apim_http_api_oauth_cb_sender(callback, retval, cfg, api_host, cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRIES, 4), cfg.get_int(name::EPISODE_APIM_TASKS_LIMIT, 4), std::chrono::milliseconds(cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } -int observation_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int observation_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::OBSERVATION_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + return create_apim_http_api_oauth_cb_sender(callback, retval, cfg, api_host, cfg.get_int(name::OBSERVATION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRIES, 4), std::chrono::milliseconds(cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } -int interaction_api_sender_oauth_create(oauth_callback_t& callback, std::unique_ptr& retval, +int interaction_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::INTERACTION_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(callback, retval, cfg, api_host, + return create_apim_http_api_oauth_cb_sender(callback, retval, cfg, api_host, cfg.get_int(name::INTERACTION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRIES, 4), std::chrono::milliseconds(cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } +int oauth_restapi_data_transport_create(std::unique_ptr& retval, + const u::configuration& config, i_trace* trace_logger, api_status* status) +{ + std::unique_ptr cred_provider; + int ret_code = get_azure_credential_provider(cred_provider, config, trace_logger, status); + if (ret_code != error_code::success) { return ret_code; } + const auto* model_uri = config.get(name::MODEL_BLOB_URI, nullptr); + if (model_uri == nullptr) { RETURN_ERROR(trace_logger, status, http_model_uri_not_provided); } + i_http_client* client = nullptr; + RETURN_IF_FAIL(create_http_client(model_uri, config, &client, status)); + retval.reset(new m::restapi_data_transport_oauth(std::unique_ptr(client), config, + m::model_source::HTTP_API, trace_logger, std::move(cred_provider), "https://storage.azure.com//.default")); + return error_code::success; +} + +int episode_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::EPISODE_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(retval, cfg, api_host, + cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRIES, 4), cfg.get_int(name::EPISODE_APIM_TASKS_LIMIT, 4), + std::chrono::milliseconds(cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + +int observation_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::OBSERVATION_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(retval, cfg, api_host, + cfg.get_int(name::OBSERVATION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRIES, 4), + std::chrono::milliseconds(cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + +int interaction_api_sender_oauth_create(std::unique_ptr& retval, + const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +{ + const auto* const api_host = cfg.get(name::INTERACTION_HTTP_API_HOST, "localhost:8080"); + return create_apim_http_api_oauth_sender(retval, cfg, api_host, + cfg.get_int(name::INTERACTION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRIES, 4), + std::chrono::milliseconds(cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, + trace_logger, status); +} + +int azure_default_cred_provider_create(std::unique_ptr& retval, const u::configuration& cfg, + i_trace* trace_logger, api_status* status) +{ + int ret_code = error_code::not_supported; +#ifdef LINK_AZURE_LIBS + retval = std::make_unique>(); + ret_code = error_code::success; +#endif + return ret_code; +} + +int azure_managed_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status) +{ + int ret_code = error_code::not_supported; +#ifdef LINK_AZURE_LIBS + const auto client_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_CLIENTID, ""); + retval = std::make_unique>(client_id); + ret_code = error_code::success; +#endif + return ret_code; +} + +int azure_azurecli_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status) +{ + int ret_code = error_code::not_supported; +#ifdef LINK_AZURE_LIBS + const auto tenant_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_TENANTID, ""); + Azure::Identity::AzureCliCredentialOptions options; + options.TenantId = tenant_id; + retval = std::make_unique>(options); + ret_code = error_code::success; +#endif + return ret_code; +} + +int azure_clientsecret_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status) +{ + int ret_code = error_code::not_supported; +#ifdef LINK_AZURE_LIBS + const auto client_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_CLIENTID, ""); + const auto tenant_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_TENANTID, ""); + const auto secret = cfg.get(name::AZURE_OAUTH_CREDENTIAL_CLIENTSECRET, ""); + retval = + std::make_unique>(tenant_id, client_id, secret); + ret_code = error_code::success; +#endif + return ret_code; +} + } // namespace reinforcement_learning diff --git a/rlclientlib/azure_factories.h b/rlclientlib/azure_factories.h index b8d61f5a8..f6adeb53d 100644 --- a/rlclientlib/azure_factories.h +++ b/rlclientlib/azure_factories.h @@ -1,5 +1,4 @@ #pragma once - #include "oauth_callback_fn.h" namespace reinforcement_learning @@ -7,4 +6,6 @@ namespace reinforcement_learning void register_azure_factories(); void register_azure_oauth_factories(oauth_callback_t& callback); -} + +using azure_cred_provider_cb_wrapper_t = std::unique_ptr; +} // namespace reinforcement_learning diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index e66d45013..393dac1c0 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -37,6 +37,7 @@ static natural_align::type modelfactory_buf; static natural_align::type senderfactory_buf; static natural_align::type traceloggerfactory_buf; static natural_align::type time_provider_factory_buf; +static natural_align::type azure_cred_provider_factory_buf; // Reference should point to the allocated memory to be initialized by placement new in // factory_initializer::factory_initializer() @@ -45,6 +46,8 @@ model_factory_t& model_factory = (model_factory_t&)(modelfactory_buf); sender_factory_t& sender_factory = (sender_factory_t&)(senderfactory_buf); trace_logger_factory_t& trace_logger_factory = (trace_logger_factory_t&)(traceloggerfactory_buf); time_provider_factory_t& time_provider_factory = (time_provider_factory_t&)(time_provider_factory_buf); +azure_cred_provider_factory_t& azure_cred_provider_factory = + (azure_cred_provider_factory_t&)(azure_cred_provider_factory_buf); factory_initializer::factory_initializer() { @@ -55,6 +58,7 @@ factory_initializer::factory_initializer() new (&sender_factory) sender_factory_t(); new (&trace_logger_factory) trace_logger_factory_t(); new (&time_provider_factory) time_provider_factory_t(); + new (&azure_cred_provider_factory) azure_cred_provider_factory_t(); register_default_factories(); } @@ -72,12 +76,7 @@ factory_initializer::~factory_initializer() } } -void register_default_factories_callback(oauth_callback_t& callback) -{ -#ifdef USE_AZURE_FACTORIES - register_azure_oauth_factories(callback); -#endif -} +void register_default_factories_callback(oauth_callback_t& callback) { register_azure_oauth_factories(callback); } template int model_create( diff --git a/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc b/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc index 35ffb74d7..6e2d0f403 100644 --- a/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc +++ b/rlclientlib/model_mgmt/restapi_data_transport_oauth.cc @@ -22,20 +22,20 @@ namespace reinforcement_learning { namespace model_management { -restapi_data_transport_oauth::restapi_data_transport_oauth( - i_http_client* httpcli, i_trace* trace, oauth_callback_t& callback, std::string scope) - : _httpcli(httpcli), _datasz{0}, _trace{trace}, _headerimpl(callback, std::move(scope)) +restapi_data_transport_oauth::restapi_data_transport_oauth(i_http_client* httpcli, i_trace* trace, + std::unique_ptr&& cred_provider, std::string scope) + : _httpcli(httpcli), _datasz{0}, _trace{trace}, _headerimpl(std::move(cred_provider), std::move(scope)) { } restapi_data_transport_oauth::restapi_data_transport_oauth(std::unique_ptr&& httpcli, - utility::configuration cfg, model_source model_source, i_trace* trace, oauth_callback_t& callback, - std::string scope) + utility::configuration cfg, model_source model_source, i_trace* trace, + std::unique_ptr&& cred_provider, std::string scope) : _httpcli(std::move(httpcli)) , _cfg(std::move(cfg)) , _model_source(model_source) , _datasz{0} , _trace{trace} - , _headerimpl(callback, std::move(scope)) + , _headerimpl(std::move(cred_provider), std::move(scope)) { } diff --git a/rlclientlib/model_mgmt/restapi_data_transport_oauth.h b/rlclientlib/model_mgmt/restapi_data_transport_oauth.h index 02c05caf3..8b055bbde 100644 --- a/rlclientlib/model_mgmt/restapi_data_transport_oauth.h +++ b/rlclientlib/model_mgmt/restapi_data_transport_oauth.h @@ -21,9 +21,11 @@ class restapi_data_transport_oauth : public i_data_transport { public: // Takes the ownership of the i_http_client and delete it at the end of lifetime - restapi_data_transport_oauth(i_http_client* httpcli, i_trace* trace, oauth_callback_t& callback, std::string scope); + restapi_data_transport_oauth(i_http_client* httpcli, i_trace* trace, + std::unique_ptr&& cred_provider, std::string scope); restapi_data_transport_oauth(std::unique_ptr&& httpcli, utility::configuration cfg, - model_source model_source, i_trace* trace, oauth_callback_t& callback, std::string scope); + model_source model_source, i_trace* trace, std::unique_ptr&& cred_provider, + std::string scope); int get_data(model_data& ret, api_status* status) override; diff --git a/rlclientlib/utility/api_header_token.h b/rlclientlib/utility/api_header_token.h index 8edc8d457..0a8d0078b 100644 --- a/rlclientlib/utility/api_header_token.h +++ b/rlclientlib/utility/api_header_token.h @@ -1,9 +1,9 @@ #pragma once #include "api_status.h" +#include "azure_credentials_provider.h" #include "configuration.h" #include "constants.h" -#include "oauth_callback_fn.h" #include "trace_logger.h" #include @@ -53,8 +53,8 @@ template class api_header_token_callback { public: - api_header_token_callback(oauth_callback_t& token_cb, std::string scope) - : _token_callback(token_cb), _scopes{std::move(scope)} + api_header_token_callback(std::unique_ptr&& cred_provider, std::string scope) + : _cred_provider(std::move(cred_provider)), _scopes{std::move(scope)} { } ~api_header_token_callback() = default; @@ -107,7 +107,7 @@ class api_header_token_callback { using namespace std::chrono; system_clock::time_point tp; - RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry, trace)); + RETURN_IF_FAIL(_cred_provider->get_token(_scopes, _bearer_token, _token_expiry, trace)); if (_bearer_token.empty()) { @@ -123,7 +123,7 @@ class api_header_token_callback private: http_headers::key_type _http_api_header_key_name; std::string _token_type; - oauth_callback_t _token_callback; + std::unique_ptr _cred_provider; std::vector _scopes; std::string _bearer_token; diff --git a/test_tools/example_gen/CMakeLists.txt b/test_tools/example_gen/CMakeLists.txt index 7eb1716e9..2d1a75717 100644 --- a/test_tools/example_gen/CMakeLists.txt +++ b/test_tools/example_gen/CMakeLists.txt @@ -3,3 +3,8 @@ add_executable(example_gen ) target_link_libraries(example_gen PRIVATE Boost::program_options rlclientlib) +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(example_gen PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(example_gen PRIVATE Azure::azure-identity) +endif() diff --git a/test_tools/sender_test/CMakeLists.txt b/test_tools/sender_test/CMakeLists.txt index 1ef260ef1..fe0fd762c 100644 --- a/test_tools/sender_test/CMakeLists.txt +++ b/test_tools/sender_test/CMakeLists.txt @@ -7,3 +7,9 @@ add_executable(sender_test target_include_directories(sender_test PRIVATE $) target_link_libraries(sender_test PRIVATE Boost::program_options rlclientlib) + +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(sender_test PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(sender_test PRIVATE Azure::azure-identity) +endif() diff --git a/unit_test/CMakeLists.txt b/unit_test/CMakeLists.txt index d857d5fd5..4e2bdbc0b 100644 --- a/unit_test/CMakeLists.txt +++ b/unit_test/CMakeLists.txt @@ -86,6 +86,12 @@ if(RL_USE_UBSAN) target_link_options(rltest PRIVATE -fno-sanitize=vptr) endif() +if(RL_LINK_AZURE_LIBS) + target_compile_definitions(rltest PRIVATE LINK_AZURE_LIBS) + find_package(azure-identity-cpp CONFIG REQUIRED) + target_link_libraries(rltest PRIVATE Azure::azure-identity) +endif() + target_link_libraries(rltest PRIVATE rlclientlib From 65ce2a789f7527a49614ef1e9351575b74d263c9 Mon Sep 17 00:00:00 2001 From: James Longo Date: Wed, 14 Aug 2024 16:02:22 -0400 Subject: [PATCH 11/19] fixed formatting added missing include --- rlclientlib/azure_factories.cc | 86 +++++++++++++++++----------------- rlclientlib/azure_factories.h | 2 + 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/rlclientlib/azure_factories.cc b/rlclientlib/azure_factories.cc index 45e62dd42..95bfbdf1a 100644 --- a/rlclientlib/azure_factories.cc +++ b/rlclientlib/azure_factories.cc @@ -1,8 +1,8 @@ #include "azure_factories.h" +#include "azure_credentials_provider.h" #include "constants.h" #include "factory_resolver.h" -#include "azure_credentials_provider.h" #include "logger/event_logger.h" #include "logger/http_transport_client.h" #include "model_mgmt/restapi_data_transport.h" @@ -47,22 +47,21 @@ int interaction_api_sender_create(std::unique_ptr& retval, const u::co int oauth_restapi_data_transport_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& config, i_trace* trace_logger, api_status* status); -int episode_api_sender_oauth_cb_create(oauth_callback_t& callback, - std::unique_ptr& retval, +int episode_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int observation_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int interaction_api_sender_oauth_cb_create(oauth_callback_t& callback, std::unique_ptr& retval, const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int oauth_restapi_data_transport_create(std::unique_ptr& retval, - const u::configuration& config, i_trace* trace_logger, api_status* status); -int episode_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int observation_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); -int interaction_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int oauth_restapi_data_transport_create(std::unique_ptr& retval, const u::configuration& config, + i_trace* trace_logger, api_status* status); +int episode_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int observation_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); +int interaction_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status); int azure_default_cred_provider_create(std::unique_ptr& retval, const u::configuration& cfg, i_trace* trace_logger, api_status* status); @@ -73,7 +72,6 @@ int azure_azurecli_cred_provider_create(std::unique_ptr& retval, const u::configuration& cfg, i_trace* trace_logger, api_status* status); - void register_azure_factories() { data_transport_factory.register_type(value::AZURE_STORAGE_BLOB, restapi_data_transport_create); @@ -92,9 +90,12 @@ void register_azure_factories() // register default azure credentials provider factories azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_DEFAULT, azure_default_cred_provider_create); - azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY, azure_managed_cred_provider_create); - azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_AZURECLI, azure_azurecli_cred_provider_create); - azure_cred_provider_factory.register_type(value::AZURE_OAUTH_CREDENTIALS_CLIENTSECRET, azure_clientsecret_cred_provider_create); + azure_cred_provider_factory.register_type( + value::AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY, azure_managed_cred_provider_create); + azure_cred_provider_factory.register_type( + value::AZURE_OAUTH_CREDENTIALS_AZURECLI, azure_azurecli_cred_provider_create); + azure_cred_provider_factory.register_type( + value::AZURE_OAUTH_CREDENTIALS_CLIENTSECRET, azure_clientsecret_cred_provider_create); // register built-in azure oauth factories data_transport_factory.register_type(value::HTTP_MODEL_DATA_OAUTH_AZ, oauth_restapi_data_transport_create); @@ -197,19 +198,18 @@ int create_apim_http_api_oauth_cb_sender(oauth_callback_t& callback, std::unique return error_code::success; } -int create_apim_http_api_oauth_sender(std::unique_ptr& retval, - const u::configuration& cfg, const char* api_host, int tasks_limit, int max_http_retries, - std::chrono::milliseconds max_http_retry_duration, error_callback_fn* error_cb, i_trace* trace_logger, - api_status* status) +int create_apim_http_api_oauth_sender(std::unique_ptr& retval, const u::configuration& cfg, + const char* api_host, int tasks_limit, int max_http_retries, std::chrono::milliseconds max_http_retry_duration, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { std::unique_ptr cred_provider; int ret_code = get_azure_credential_provider(cred_provider, cfg, trace_logger, status); if (ret_code != error_code::success) { return ret_code; } i_http_client* client = nullptr; RETURN_IF_FAIL(create_http_client(api_host, cfg, &client, status)); - retval.reset( - new http_transport_client>(client, tasks_limit, max_http_retries, - max_http_retry_duration, trace_logger, error_cb, std::move(cred_provider), "https://eventhubs.azure.net//.default")); + retval.reset(new http_transport_client>(client, tasks_limit, + max_http_retries, max_http_retry_duration, trace_logger, error_cb, std::move(cred_provider), + "https://eventhubs.azure.net//.default")); return error_code::success; } @@ -322,8 +322,8 @@ int interaction_api_sender_oauth_cb_create(oauth_callback_t& callback, std::uniq trace_logger, status); } -int oauth_restapi_data_transport_create(std::unique_ptr& retval, - const u::configuration& config, i_trace* trace_logger, api_status* status) +int oauth_restapi_data_transport_create(std::unique_ptr& retval, const u::configuration& config, + i_trace* trace_logger, api_status* status) { std::unique_ptr cred_provider; int ret_code = get_azure_credential_provider(cred_provider, config, trace_logger, status); @@ -337,38 +337,38 @@ int oauth_restapi_data_transport_create(std::unique_ptr& re return error_code::success; } -int episode_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +int episode_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::EPISODE_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(retval, cfg, api_host, - cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRIES, 4), cfg.get_int(name::EPISODE_APIM_TASKS_LIMIT, 4), + return create_apim_http_api_oauth_sender(retval, cfg, api_host, cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRIES, 4), + cfg.get_int(name::EPISODE_APIM_TASKS_LIMIT, 4), std::chrono::milliseconds(cfg.get_int(name::EPISODE_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } -int observation_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +int observation_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::OBSERVATION_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(retval, cfg, api_host, - cfg.get_int(name::OBSERVATION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRIES, 4), + return create_apim_http_api_oauth_sender(retval, cfg, api_host, cfg.get_int(name::OBSERVATION_APIM_TASKS_LIMIT, 16), + cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRIES, 4), std::chrono::milliseconds(cfg.get_int(name::OBSERVATION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } -int interaction_api_sender_oauth_create(std::unique_ptr& retval, - const u::configuration& cfg, error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) +int interaction_api_sender_oauth_create(std::unique_ptr& retval, const u::configuration& cfg, + error_callback_fn* error_cb, i_trace* trace_logger, api_status* status) { const auto* const api_host = cfg.get(name::INTERACTION_HTTP_API_HOST, "localhost:8080"); - return create_apim_http_api_oauth_sender(retval, cfg, api_host, - cfg.get_int(name::INTERACTION_APIM_TASKS_LIMIT, 16), cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRIES, 4), + return create_apim_http_api_oauth_sender(retval, cfg, api_host, cfg.get_int(name::INTERACTION_APIM_TASKS_LIMIT, 16), + cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRIES, 4), std::chrono::milliseconds(cfg.get_int(name::INTERACTION_APIM_MAX_HTTP_RETRY_DURATION_MS, 3600000)), error_cb, trace_logger, status); } -int azure_default_cred_provider_create(std::unique_ptr& retval, const u::configuration& cfg, - i_trace* trace_logger, api_status* status) +int azure_default_cred_provider_create(std::unique_ptr& retval, + const u::configuration& cfg, i_trace* trace_logger, api_status* status) { int ret_code = error_code::not_supported; #ifdef LINK_AZURE_LIBS @@ -379,7 +379,7 @@ int azure_default_cred_provider_create(std::unique_ptr& retval, - const u::configuration& cfg, i_trace* trace_logger, api_status* status) + const u::configuration& cfg, i_trace* trace_logger, api_status* status) { int ret_code = error_code::not_supported; #ifdef LINK_AZURE_LIBS @@ -391,7 +391,7 @@ int azure_managed_cred_provider_create(std::unique_ptr& retval, - const u::configuration& cfg, i_trace* trace_logger, api_status* status) + const u::configuration& cfg, i_trace* trace_logger, api_status* status) { int ret_code = error_code::not_supported; #ifdef LINK_AZURE_LIBS @@ -405,15 +405,15 @@ int azure_azurecli_cred_provider_create(std::unique_ptr& retval, - const u::configuration& cfg, i_trace* trace_logger, api_status* status) + const u::configuration& cfg, i_trace* trace_logger, api_status* status) { int ret_code = error_code::not_supported; #ifdef LINK_AZURE_LIBS const auto client_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_CLIENTID, ""); const auto tenant_id = cfg.get(name::AZURE_OAUTH_CREDENTIAL_TENANTID, ""); const auto secret = cfg.get(name::AZURE_OAUTH_CREDENTIAL_CLIENTSECRET, ""); - retval = - std::make_unique>(tenant_id, client_id, secret); + retval = std::make_unique>( + tenant_id, client_id, secret); ret_code = error_code::success; #endif return ret_code; diff --git a/rlclientlib/azure_factories.h b/rlclientlib/azure_factories.h index f6adeb53d..e2d62132d 100644 --- a/rlclientlib/azure_factories.h +++ b/rlclientlib/azure_factories.h @@ -1,6 +1,8 @@ #pragma once #include "oauth_callback_fn.h" +#include + namespace reinforcement_learning { void register_azure_factories(); From 70894c8e641d5c30ddbf7780d116ce49f8b44a1f Mon Sep 17 00:00:00 2001 From: James Longo Date: Wed, 21 Aug 2024 14:26:26 -0400 Subject: [PATCH 12/19] added missing call to azure_cred_provider_factory_t dtor in factory_initializer --- rlclientlib/factory_resolver.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/rlclientlib/factory_resolver.cc b/rlclientlib/factory_resolver.cc index 393dac1c0..2ecea80fd 100644 --- a/rlclientlib/factory_resolver.cc +++ b/rlclientlib/factory_resolver.cc @@ -73,6 +73,7 @@ factory_initializer::~factory_initializer() (&sender_factory)->~sender_factory_t(); (&trace_logger_factory)->~trace_logger_factory_t(); (&time_provider_factory)->~time_provider_factory_t(); + (&azure_cred_provider_factory)->~azure_cred_provider_factory_t(); } } From ae91b1c0093412c18a9e6f79aa60046e95498314 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 10:19:42 -0400 Subject: [PATCH 13/19] update the job name for macos so that we don't need to change the status check names (start with asan) --- .github/workflows/asan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 8a291819f..75d7eedc4 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -25,7 +25,7 @@ jobs: # UBSan not supported by MSVC on Windows - { os: windows-latest, preset: vcpkg-ubsan-debug } runs-on: ${{ matrix.os }} - name: asan.${{ matrix.os }}.${{ matrix.preset }} + name: asan.${{ matrix.os == 'macos-13' ? 'macos' : matrix.os }}.${{ matrix.preset }} env: UBSAN_OPTIONS: "print_stacktrace=1" From e03dd80233ff67090410ca6307bcfa5722cbd580 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 10:53:04 -0400 Subject: [PATCH 14/19] fix using ternary operator expression --- .github/workflows/asan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 75d7eedc4..b20c2ea5b 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -25,7 +25,7 @@ jobs: # UBSan not supported by MSVC on Windows - { os: windows-latest, preset: vcpkg-ubsan-debug } runs-on: ${{ matrix.os }} - name: asan.${{ matrix.os == 'macos-13' ? 'macos' : matrix.os }}.${{ matrix.preset }} + name: asan.${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}.${{ matrix.preset }} env: UBSAN_OPTIONS: "print_stacktrace=1" From 97c488c9527ec1983471a4ed5c64aa12b53e4629 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 11:11:09 -0400 Subject: [PATCH 15/19] updated the vcpkg build to align the macos package names with the status checks --- .github/workflows/vcpkg_build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index 7d2b456fd..dc08800a0 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -18,7 +18,7 @@ concurrency: jobs: job: - name: ${{ matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} + name: ${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -39,7 +39,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg' From f24f886bdcbd05f463f754730bb458e2c5767d16 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 12:38:57 -0400 Subject: [PATCH 16/19] the job name should 'macos-latest' --- .github/workflows/asan.yml | 2 +- .github/workflows/vcpkg_build.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index b20c2ea5b..01b18c4da 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -25,7 +25,7 @@ jobs: # UBSan not supported by MSVC on Windows - { os: windows-latest, preset: vcpkg-ubsan-debug } runs-on: ${{ matrix.os }} - name: asan.${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}.${{ matrix.preset }} + name: asan.${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}.${{ matrix.preset }} env: UBSAN_OPTIONS: "print_stacktrace=1" diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index dc08800a0..f21a0ed93 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -18,7 +18,7 @@ concurrency: jobs: job: - name: ${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} + name: ${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -39,7 +39,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.os == 'macos-13' && 'macos' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg' From 6f3ff3d221474937103e2233fbf52e4f7499b3a9 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 14:49:47 -0400 Subject: [PATCH 17/19] updated another macos-13 os tag to macos-latest --- .github/workflows/dotnet_nugets.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dotnet_nugets.yml b/.github/workflows/dotnet_nugets.yml index 44116260b..433e914e3 100644 --- a/.github/workflows/dotnet_nugets.yml +++ b/.github/workflows/dotnet_nugets.yml @@ -52,7 +52,7 @@ jobs: uses: actions/cache@v3 with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} + key: ${{ matrix.config.os == 'macos-13' && 'macos-latest' || matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} - name: Configure .NET Core run: > From ae768f468a819a2e9aa23e4d7e7594af43cdca5d Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 15:07:45 -0400 Subject: [PATCH 18/19] changed macos-13 exact match to startsWith added missing macos-13 transforms --- .github/workflows/asan.yml | 2 +- .github/workflows/build_vw_bp.yml | 4 ++-- .github/workflows/dotnet_nugets.yml | 2 +- .github/workflows/vcpkg_build.yml | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 01b18c4da..800bab7f6 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -25,7 +25,7 @@ jobs: # UBSan not supported by MSVC on Windows - { os: windows-latest, preset: vcpkg-ubsan-debug } runs-on: ${{ matrix.os }} - name: asan.${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}.${{ matrix.preset }} + name: asan.${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}.${{ matrix.preset }} env: UBSAN_OPTIONS: "print_stacktrace=1" diff --git a/.github/workflows/build_vw_bp.yml b/.github/workflows/build_vw_bp.yml index 53c625eab..ac652b568 100644 --- a/.github/workflows/build_vw_bp.yml +++ b/.github/workflows/build_vw_bp.yml @@ -18,7 +18,7 @@ concurrency: jobs: build-binary-parser: - name: binary-parser-${{ matrix.build.build_type }}-${{ matrix.config.os }} + name: binary-parser-${{ matrix.build.build_type }}-${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }} runs-on: ${{ matrix.config.os }} strategy: fail-fast: false @@ -51,7 +51,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.config.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg' diff --git a/.github/workflows/dotnet_nugets.yml b/.github/workflows/dotnet_nugets.yml index 433e914e3..e90facf86 100644 --- a/.github/workflows/dotnet_nugets.yml +++ b/.github/workflows/dotnet_nugets.yml @@ -52,7 +52,7 @@ jobs: uses: actions/cache@v3 with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.config.os == 'macos-13' && 'macos-latest' || matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} + key: ${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} - name: Configure .NET Core run: > diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index f21a0ed93..67ac6bdb5 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -18,7 +18,7 @@ concurrency: jobs: job: - name: ${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} + name: ${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -39,7 +39,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ matrix.os == 'macos-13' && 'macos-latest' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg' From bbd92c1e7543930e456cd755229b79e7b0d3c396 Mon Sep 17 00:00:00 2001 From: James Longo Date: Thu, 22 Aug 2024 15:43:55 -0400 Subject: [PATCH 19/19] corrected startsWith 'macos-13' to 'macos' added job names to test-nuget and build-nuget-dotnet to match the status checks --- .github/workflows/asan.yml | 2 +- .github/workflows/build_vw_bp.yml | 4 ++-- .github/workflows/dotnet_nugets.yml | 4 +++- .github/workflows/vcpkg_build.yml | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/asan.yml b/.github/workflows/asan.yml index 800bab7f6..a07c32e55 100644 --- a/.github/workflows/asan.yml +++ b/.github/workflows/asan.yml @@ -25,7 +25,7 @@ jobs: # UBSan not supported by MSVC on Windows - { os: windows-latest, preset: vcpkg-ubsan-debug } runs-on: ${{ matrix.os }} - name: asan.${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}.${{ matrix.preset }} + name: asan.${{ startsWith(matrix.os, 'macos') && 'macos-latest' || matrix.os }}.${{ matrix.preset }} env: UBSAN_OPTIONS: "print_stacktrace=1" diff --git a/.github/workflows/build_vw_bp.yml b/.github/workflows/build_vw_bp.yml index ac652b568..5d549fe32 100644 --- a/.github/workflows/build_vw_bp.yml +++ b/.github/workflows/build_vw_bp.yml @@ -18,7 +18,7 @@ concurrency: jobs: build-binary-parser: - name: binary-parser-${{ matrix.build.build_type }}-${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }} + name: binary-parser-${{ matrix.build.build_type }}-${{ startsWith(matrix.config.os, 'macos') && 'macos-latest' || matrix.config.os }} runs-on: ${{ matrix.config.os }} strategy: fail-fast: false @@ -51,7 +51,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ startsWith(matrix.config.os, 'macos') && 'macos-latest' || matrix.config.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg' diff --git a/.github/workflows/dotnet_nugets.yml b/.github/workflows/dotnet_nugets.yml index e90facf86..323f4257a 100644 --- a/.github/workflows/dotnet_nugets.yml +++ b/.github/workflows/dotnet_nugets.yml @@ -31,6 +31,7 @@ jobs: - { os: "ubuntu-latest", runtime_id: "linux-x64", vcpkg_target_triplet: "x64-linux" } - { os: "macos-13", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" } runs-on: ${{matrix.config.os}} + name: build-nuget-dotnet (${{ startsWith(matrix.config.os, 'macos') && 'macos-latest' || matrix.config.os }}, ${{ matrix.config.runtime_id }}, ${{ matrix.config.vcpkg_target_triplet }}) steps: - uses: actions/checkout@v2 - run: | @@ -52,7 +53,7 @@ jobs: uses: actions/cache@v3 with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ startsWith(matrix.config.os, 'macos-13') && 'macos-latest' || matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} + key: ${{ startsWith(matrix.config.os, 'macos') && 'macos-latest' || matrix.config.os }}-build-${{ matrix.config.vcpkg_target_triplet }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }} - name: Configure .NET Core run: > @@ -164,6 +165,7 @@ jobs: - { os: "ubuntu-latest", runtime_id: "linux-x64" } - { os: "macos-13", runtime_id: "osx-x64" } runs-on: ${{matrix.config.os}} + name: test-nuget (${{ startsWith(matrix.config.os, 'macos') && 'macos-latest' || matrix.config.os }}, ${{ matrix.config.runtime_id }}) steps: - uses: actions/checkout@v2 - name: Update git tags diff --git a/.github/workflows/vcpkg_build.yml b/.github/workflows/vcpkg_build.yml index 67ac6bdb5..6535296cf 100644 --- a/.github/workflows/vcpkg_build.yml +++ b/.github/workflows/vcpkg_build.yml @@ -18,7 +18,7 @@ concurrency: jobs: job: - name: ${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} + name: ${{ startsWith(matrix.os, 'macos') && 'macos-latest' || matrix.os }}-${{ matrix.preset }}-${{ github.workflow }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -39,7 +39,7 @@ jobs: cache-name: vcpkg-cache with: path: ${{ env.VCPKG_DEFAULT_BINARY_CACHE }}/* - key: ${{ startsWith(matrix.os, 'macos-13') && 'macos-latest' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" + key: ${{ startsWith(matrix.os, 'macos') && 'macos-latest' || matrix.os }}-build-${{ env.cache-name }}-${{ hashFiles('vcpkg.json') }}-${{ env.VCPKG_COMMIT }}" - uses: lukka/run-vcpkg@v10 with: vcpkgDirectory: '${{ github.workspace }}/ext_libs/vcpkg'