Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rlclientlib and rl.net to support azure credential hooks #609

Merged
merged 19 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
63216c9
added azure_credentials_provider.h to the public headers
v-jameslongo Aug 8, 2024
484c5e5
added memory and azure include to the azure_credentials_provider
v-jameslongo Aug 9, 2024
cd6e87a
added trace logger to the azure credential providers
v-jameslongo Aug 9, 2024
ebfcfe7
removed trace from lock scope
v-jameslongo Aug 9, 2024
14303e8
added set(CMAKE_DEBUG_POSTFIX "") to ensure VS builds have the same l…
v-jameslongo Aug 9, 2024
fc2fc17
updated formatting (according to tidy)
v-jameslongo Aug 9, 2024
060ce8a
moved find dotnet-t4 to FindDotnet.cmake and check it for windows onl…
v-jameslongo Aug 12, 2024
ac470e6
the x64 macos image was changed to macos-latest-large. see (https://g…
v-jameslongo Aug 12, 2024
56ff878
for now, moving to build runner macos-13 supporting x64
v-jameslongo Aug 12, 2024
f05a552
updates to provide default azure credential implementations via the f…
v-jameslongo Aug 14, 2024
65ce2a7
fixed formatting
v-jameslongo Aug 14, 2024
70894c8
added missing call to azure_cred_provider_factory_t dtor in factory_…
v-jameslongo Aug 21, 2024
ae91b1c
update the job name for macos so that we don't need to change the sta…
v-jameslongo Aug 22, 2024
e03dd80
fix using ternary operator expression
v-jameslongo Aug 22, 2024
97c488c
updated the vcpkg build to align the macos package names with the sta…
v-jameslongo Aug 22, 2024
f24f886
the job name should 'macos-latest'
v-jameslongo Aug 22, 2024
6f3ff3d
updated another macos-13 os tag to macos-latest
v-jameslongo Aug 22, 2024
ae768f4
changed macos-13 exact match to startsWith
v-jameslongo Aug 22, 2024
bbd92c1
corrected startsWith 'macos-13' to 'macos'
v-jameslongo Aug 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/asan.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
strategy:
fail-fast: false
matrix:
#os: [windows-latest, ubuntu-latest, macos-latest]
os: [ubuntu-latest, macos-latest] # Temporarily remove windows asan
#os: [windows-latest, ubuntu-latest, macos-13]
os: [ubuntu-latest, macos-13] # Temporarily remove windows asan
preset: [vcpkg-asan-debug, vcpkg-ubsan-debug]
exclude:
# UBSan not supported by MSVC on Windows
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_rlclientlib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
build-macos:
# Mac build doesn't have any additional features enabled
name: rlclientlib-${{ matrix.build_type }}-macos-latest
runs-on: macos-latest
runs-on: macos-13
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build_vw_bp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
config:
- { os: "windows-latest", vcpkg_target_triplet: "x64-windows-static" }
- { os: "ubuntu-latest", vcpkg_target_triplet: "x64-linux" }
- { os: "macos-latest", vcpkg_target_triplet: "x64-osx" }
- { os: "macos-13", vcpkg_target_triplet: "x64-osx" }
build:
# Set the appropriate static runtime for MSVC on Windows
# CMake ignores the CMAKE_MSVC_RUNTIME_LIBRARY option on other platforms
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/dotnet_nugets.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
config:
- { os: "windows-latest", runtime_id: "win-x64", vcpkg_target_triplet: "x64-windows-static" }
- { os: "ubuntu-latest", runtime_id: "linux-x64", vcpkg_target_triplet: "x64-linux" }
- { os: "macos-latest", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" }
- { os: "macos-13", runtime_id: "osx-x64", vcpkg_target_triplet: "x64-osx" }
runs-on: ${{matrix.config.os}}
steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -162,7 +162,7 @@ jobs:
config:
- { os: "windows-latest", runtime_id: "win-x64" }
- { os: "ubuntu-latest", runtime_id: "linux-x64" }
- { os: "macos-latest", runtime_id: "osx-x64" }
- { os: "macos-13", runtime_id: "osx-x64" }
runs-on: ${{matrix.config.os}}
steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/vcpkg_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest, macos-13, windows-latest]
preset: [vcpkg-debug, vcpkg-release]
steps:
- uses: actions/checkout@v3
Expand Down
7 changes: 7 additions & 0 deletions bindings/cs/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")
include(FindDotnet)

# note: this change was made since building with Ninja does not add suffixes
# but, using the VS generator does. rl.net uses dllimport to load rlnetnative.
# This is a workaround to make sure the correct dll is used.
if (WIN32 AND CMAKE_GENERATOR MATCHES "Visual Studio")
set(CMAKE_DEBUG_POSTFIX "")
endif()

add_subdirectory(rl.net.native)
add_subdirectory(rl.net)
add_subdirectory(rl.net.cli)
Expand Down
16 changes: 13 additions & 3 deletions bindings/cs/common/codegen/TextTemplate.targets
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
</ItemGroup>

<PropertyGroup>
<T4Command>dotnet t4</T4Command>
<T4Command Condition="'$(OS)' == 'Windows_NT'">t4</T4Command>
<!-- if this fails you may need to install dotnet-t4 as follows -->
<!-- dotnet tool install -g dotnet-t4 -->
<T4Command Condition="'$(T4Command)' == ''">t4</T4Command>
</PropertyGroup>

<Target Name="TextTemplateTransform" BeforeTargets="BeforeBuild">
Expand All @@ -17,14 +18,23 @@
<MakeDir Directories="$(IntermediateOutputPath)" />
<PropertyGroup>
<Parameters>@(TextTransformParameter -> '-a !!%(Identity)!%(Value)', ' ')</Parameters>
<TaskCommand>$(T4Command) "%(TextTemplate.Identity)" -o "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters)</TaskCommand>
<TaskCommand Condition="'$(OS)' != 'Windows_NT'">$(T4Command) "%(TextTemplate.Identity)" -o "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters)</TaskCommand>
<TaskCommand Condition="'$(OS)' == 'Windows_NT'">$(T4Command) "%(TextTemplate.Identity)" -out "$([System.IO.Path]::Combine($(IntermediateOutputPath),%(TextTemplate.FileName).T4Generated.cs))" $(Parameters)</TaskCommand>
</PropertyGroup>

<Exec WorkingDirectory="$(ProjectDir)" Command="$(TaskCommand)"/>

<ItemGroup>
<Compile Include="$(IntermediateOutputPath)\*.T4Generated.cs"/>
</ItemGroup>

<OnError ExecuteTargets="HandleT4Failure" />
</Target>

<Target Name="HandleT4Failure">
<Message Text="T4 template failed to transform: %(TextTemplate.Identity); check if you have dotnet-t4 installed. if not, install it with: dotnet tool install -g dotnet-t4" Importance="high"/>
<Error Text="T4 template failed to transform: %(TextTemplate.Identity)" />
<Error Text="check if you have dotnet-t4 installed. if not, install it with: dotnet tool install -g dotnet-t4" />
</Target>

<Target Name="TextTemplateClean" AfterTargets="Clean">
Expand Down
9 changes: 9 additions & 0 deletions bindings/cs/rl.net.native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ set(rl_net_native_HEADERS
rl.net.slot_ranking.h
)

if(RL_LINK_AZURE_LIBS)
list(APPEND rl_net_native_HEADERS
rl.net.azure_factories.h
)
list(APPEND rl_net_native_SOURCES
rl.net.azure_factories.cc
)
endif()

source_group("Sources" FILES ${rl_net_native_SOURCES})
source_group("Headers" FILES ${rl_net_native_HEADERS})

Expand Down
46 changes: 46 additions & 0 deletions bindings/cs/rl.net.native/rl.net.azure_factories.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#include "rl.net.azure_factories.h"

namespace rl_net_native
{
rl_net_native::azure_factory_oauth_callback_t g_oauth_callback = nullptr;
rl_net_native::azure_factory_oauth_callback_complete_t g_oauth_callback_complete = nullptr;

static int azure_factory_oauth_callback(const std::vector<std::string>& scopes, std::string& oauth_token,
std::chrono::system_clock::time_point& token_expiry, reinforcement_learning::i_trace* trace)
{
if (g_oauth_callback == nullptr) { return reinforcement_learning::error_code::invalid_argument; }
// create a null terminated array of scope string pointers.
// these are pointers are readonly and owned by the caller.
// since we create the vector of n elements, we can guarantee
// the last element is null.
std::vector<const char*> native_scopes(scopes.size() + 1);
for (int i = 0; i < scopes.size(); ++i) { native_scopes[i] = scopes[i].c_str(); }
// we expect to get a pointer to a null terminated string.
// it's allocated on the managed heap, so we can't free it here
// instead, we will pass it back to the managed code after we copy it.
char* oauth_token_ptr = nullptr;
std::int64_t expiryUnixTime = 0;
auto ret = g_oauth_callback(native_scopes.data(), &oauth_token_ptr, &expiryUnixTime);
if (ret == reinforcement_learning::error_code::success)
{
TRACE_DEBUG(trace, "rl_net_native::azure_factory_oauth_callback: successfully retrieved token");
oauth_token = oauth_token_ptr;
token_expiry = std::chrono::system_clock::from_time_t(expiryUnixTime);
}
else { TRACE_ERROR(trace, "rl_net_native::azure_factory_oauth_callback: failed to retrieve token"); }
g_oauth_callback_complete(oauth_token_ptr, reinforcement_learning::error_code::success);
return ret;
}
} // namespace rl_net_native

API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback,
rl_net_native::azure_factory_oauth_callback_complete_t completion)
{
if (rl_net_native::g_oauth_callback == nullptr)
{
rl_net_native::g_oauth_callback = callback;
reinforcement_learning::register_default_factories_callback(
reinforcement_learning::oauth_callback_t{rl_net_native::azure_factory_oauth_callback});
rl_net_native::g_oauth_callback_complete = completion;
}
}
35 changes: 35 additions & 0 deletions bindings/cs/rl.net.native/rl.net.azure_factories.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include "factory_resolver.h"
#include "rl.net.native.h"

#include <cstddef>

namespace rl_net_native
{
// Callback function to be called by the C# code to get the OAuth token
// The callback function should return 0 on success and non-zero on failure (see reinforcement_learning::error_code)
// prototype - azure_factory_oauth_callback(const char** scopes, char** token, std::int64_t* expiryUnixTime)
// scopes - null terminated array of scope strings
// token - out pointer to a null terminated string allocated on the managed heap
// expiryUnixTime - pointer to the expiry time of the token in Unix time
typedef int (*azure_factory_oauth_callback_t)(const char**, char**, std::int64_t*);

// Callback function to be called by the C# code to signal the completion of the OAuth token request
// this provides the C# code with the opportunity to free the memory allocated for the token
// tokenStringToFree - pointer to the token string than needs to be freed
// errorCode - for future use
typedef void (*azure_factory_oauth_callback_complete_t)(char*, int);
} // namespace rl_net_native

extern "C"
{
// Register the callback function to be called by the C++ code to get the OAuth token
// both the callback and completion functions must be registered before any calls to the Azure factories are made
// typically, this function should be called once during the initialization of the application
// callback - the callback function to be called by the C++ code to get the OAuth token
// completion - the callback function to be called by the C++ code to signal the completion of the OAuth token
// request
API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback,
rl_net_native::azure_factory_oauth_callback_complete_t completion);
}
12 changes: 6 additions & 6 deletions bindings/cs/rl.net/ApiStatus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,22 @@
namespace Rl.Net {
public sealed class ApiStatus : NativeObject<ApiStatus>
{
[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern IntPtr CreateApiStatus();

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern void DeleteApiStatus(IntPtr config);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern IntPtr GetApiStatusErrorMessage(IntPtr status);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern int GetApiStatusErrorCode(IntPtr status);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern void UpdateApiStatusSafe(IntPtr status, int error_code, IntPtr message);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
private static extern void ClearApiStatusSafe(IntPtr status);

public ApiStatus() : base(new New<ApiStatus>(CreateApiStatus), new Delete<ApiStatus>(DeleteApiStatus))
Expand Down
24 changes: 12 additions & 12 deletions bindings/cs/rl.net/CALoop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ namespace Native
// The publics in this class are just a verbose, but jittably-efficient way of enabling overriding a native invocation
internal static partial class NativeMethods
{
[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern IntPtr CreateCALoop(IntPtr config, IntPtr factoryContext);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern void DeleteCALoop(IntPtr caLoop);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern int CALoopInit(IntPtr caLoop, IntPtr apiStatus);

[DllImport("rlnetnative", EntryPoint = "CALoopRequestContinuousAction")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopRequestContinuousAction")]
private static extern int CALoopRequestContinuousActionNative(IntPtr caLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, IntPtr continuousActionResponse, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, IntPtr, int, IntPtr, IntPtr, int> CALoopRequestContinuousActionOverride { get; set; }
Expand All @@ -34,7 +34,7 @@ public static int CALoopRequestContinuousAction(IntPtr caLoop, IntPtr eventId, I
return CALoopRequestContinuousActionNative(caLoop, eventId, contextJson, contextJsonSize, continuousActionResponse, apiStatus);
}

[DllImport("rlnetnative", EntryPoint = "CALoopRequestContinuousActionWithFlags")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopRequestContinuousActionWithFlags")]
private static extern int CALoopRequestContinuousActionWithFlagsNative(IntPtr caLoop, IntPtr eventId, IntPtr contextJson, int contextJsonSize, uint flags, IntPtr continuousActionResponse, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, IntPtr, int, uint, IntPtr, IntPtr, int> CALoopRequestContinuousActionWithFlagsOverride { get; set; }
Expand All @@ -49,7 +49,7 @@ public static int CALoopRequestContinuousActionWithFlags(IntPtr caLoop, IntPtr e
return CALoopRequestContinuousActionWithFlagsNative(caLoop, eventId, contextJson, contextJsonSize, flags, continuousActionResponse, apiStatus);
}

[DllImport("rlnetnative", EntryPoint = "CALoopReportActionTaken")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportActionTaken")]
private static extern int CALoopReportActionTakenNative(IntPtr caLoop, IntPtr eventId, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, IntPtr, int> CALoopReportActionTakenOverride { get; set; }
Expand All @@ -64,7 +64,7 @@ public static int CALoopReportActionTaken(IntPtr caLoop, IntPtr eventId, IntPtr
return CALoopReportActionTakenNative(caLoop, eventId, apiStatus);
}

[DllImport("rlnetnative", EntryPoint = "CALoopReportActionMultiIdTaken")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportActionMultiIdTaken")]
private static extern int CALoopReportActionTakenMultiIdNative(IntPtr caLoop, IntPtr primaryId, IntPtr secondaryId, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, IntPtr, IntPtr, int> CALoopReportActionTakenMultiIdOverride { get; set; }
Expand All @@ -79,7 +79,7 @@ public static int CALoopReportActionMultiIdTaken(IntPtr caLoop, IntPtr primaryId
return CALoopReportActionTakenMultiIdNative(caLoop, primaryId, secondaryId, apiStatus);
}

[DllImport("rlnetnative", EntryPoint = "CALoopReportOutcomeF")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportOutcomeF")]
private static extern int CALoopReportOutcomeFNative(IntPtr caLoop, IntPtr eventId, float outcome, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, float, IntPtr, int> CALoopReportOutcomeFOverride { get; set; }
Expand All @@ -94,7 +94,7 @@ public static int CALoopReportOutcomeF(IntPtr caLoop, IntPtr eventId, float outc
return CALoopReportOutcomeFNative(caLoop, eventId, outcome, apiStatus);
}

[DllImport("rlnetnative", EntryPoint = "CALoopReportOutcomeJson")]
[DllImport(NativeImports.RLNETNATIVE, EntryPoint = "CALoopReportOutcomeJson")]
private static extern int CALoopReportOutcomeJsonNative(IntPtr caLoop, IntPtr eventId, IntPtr outcomeJson, IntPtr apiStatus);

internal static Func<IntPtr, IntPtr, IntPtr, IntPtr, int> CALoopReportOutcomeJsonOverride { get; set; }
Expand All @@ -109,13 +109,13 @@ public static int CALoopReportOutcomeJson(IntPtr caLoop, IntPtr eventId, IntPtr
return CALoopReportOutcomeJsonNative(caLoop, eventId, outcomeJson, apiStatus);
}

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern int CALoopRefreshModel(IntPtr caLoop, IntPtr apiStatus);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern void CALoopSetCallback(IntPtr caLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_background_error_callback_t callback = null);

[DllImport("rlnetnative")]
[DllImport(NativeImports.RLNETNATIVE)]
public static extern void CALoopSetTrace(IntPtr caLoop, [MarshalAs(UnmanagedType.FunctionPtr)] managed_trace_callback_t callback = null);
}
}
Expand Down
Loading
Loading