Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rlclientlib and rl.net to support azure credential hooks #609

Merged
merged 19 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
63216c9
added azure_credentials_provider.h to the public headers
v-jameslongo Aug 8, 2024
484c5e5
added memory and azure include to the azure_credentials_provider
v-jameslongo Aug 9, 2024
cd6e87a
added trace logger to the azure credential providers
v-jameslongo Aug 9, 2024
ebfcfe7
removed trace from lock scope
v-jameslongo Aug 9, 2024
14303e8
added set(CMAKE_DEBUG_POSTFIX "") to ensure VS builds have the same l…
v-jameslongo Aug 9, 2024
fc2fc17
updated formatting (according to tidy)
v-jameslongo Aug 9, 2024
060ce8a
moved find dotnet-t4 to FindDotnet.cmake and check it for windows onl…
v-jameslongo Aug 12, 2024
ac470e6
the x64 macos image was changed to macos-latest-large. see (https://g…
v-jameslongo Aug 12, 2024
56ff878
for now, moving to build runner macos-13 supporting x64
v-jameslongo Aug 12, 2024
f05a552
updates to provide default azure credential implementations via the f…
v-jameslongo Aug 14, 2024
65ce2a7
fixed formatting
v-jameslongo Aug 14, 2024
70894c8
added missing call to azure_cred_provider_factory_t dtor in factory_…
v-jameslongo Aug 21, 2024
ae91b1c
update the job name for macos so that we don't need to change the sta…
v-jameslongo Aug 22, 2024
e03dd80
fix using ternary operator expression
v-jameslongo Aug 22, 2024
97c488c
updated the vcpkg build to align the macos package names with the sta…
v-jameslongo Aug 22, 2024
f24f886
the job name should 'macos-latest'
v-jameslongo Aug 22, 2024
6f3ff3d
updated another macos-13 os tag to macos-latest
v-jameslongo Aug 22, 2024
ae768f4
changed macos-13 exact match to startsWith
v-jameslongo Aug 22, 2024
bbd92c1
corrected startsWith 'macos-13' to 'macos'
v-jameslongo Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ if(POLICY CMP0091)
cmake_policy(SET CMP0091 NEW)
endif()

# ensure all of the build tools generate the same output on all platforms
# note: this change was made since building with Ninja does not add suffixes
# but, using the VS generator does.
set(CMAKE_DEBUG_POSTFIX "")

if(WIN32)
# Due to needing to configure the CMAKE platform, this needs to be included before the
# top-level project() declaration.
Expand Down
7 changes: 7 additions & 0 deletions bindings/cs/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")
include(FindDotnet)

# note: this change was made since building with Ninja does not add suffixes
# but, using the VS generator does. rl.net uses dllimport to load rlnetnative.
# This is a workaround to make sure the correct dll is used.
if (WIN32 AND CMAKE_GENERATOR MATCHES "Visual Studio")
set(CMAKE_DEBUG_POSTFIX "")
endif()

add_subdirectory(rl.net.native)
add_subdirectory(rl.net)
add_subdirectory(rl.net.cli)
Expand Down
6 changes: 6 additions & 0 deletions examples/basic_usage_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ add_executable(basic_usage_cpp.out
)

target_link_libraries(basic_usage_cpp.out PRIVATE rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(basic_usage_cpp.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(basic_usage_cpp.out PRIVATE Azure::azure-identity)
endif()
6 changes: 6 additions & 0 deletions examples/override_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ add_executable(override_interface.out
)

target_link_libraries(override_interface.out PRIVATE rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(override_interface.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(override_interface.out PRIVATE Azure::azure-identity)
endif()
3 changes: 1 addition & 2 deletions examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ po::variables_map process_cmd_line(const int argc, char** argv)
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value<bool>()->default_value(true),
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")(
"azure_oauth_factories", po::value<bool>()->default_value(false), "Use oauth for azure factores. Default false");
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
18 changes: 10 additions & 8 deletions examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,18 @@ int rl_sim::init_rl()
sender_factory = &factory;
}
// probably incompatible with the throughput option?
else if (_options["azure_oauth_factories"].as<bool>())
{
#ifdef LINK_AZURE_LIBS
// Note: This requires C++14 or better
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
// Note: The azure_oauth_factories switch has been removed since registering the factories by default does not
// directly utilize the callback. The factory implementations need to be configured in the configuration file
// for them to function (see factory_resolver.h).
// Note: If USE_AZURE_FACTORIES is defined, the default factory implementations will be different (see
// factory_resolver.h).
// Note: This requires C++14 or better since the Azure libraries require C++14 or better.
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_token, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
#endif
}

// Initialize the API
_rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
Expand Down
6 changes: 6 additions & 0 deletions examples/test_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ add_executable(rl_test.out
target_include_directories(rl_test.out PRIVATE $<TARGET_PROPERTY:rlclientlib,INCLUDE_DIRECTORIES>)

target_link_libraries(rl_test.out PRIVATE Boost::program_options rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(rl_test.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(rl_test.out PRIVATE Azure::azure-identity)
endif()
114 changes: 86 additions & 28 deletions include/azure_credentials_provider.h
Original file line number Diff line number Diff line change
@@ -1,37 +1,87 @@
#pragma once

#ifdef LINK_AZURE_LIBS
# include "err_constants.h"
# include "oauth_callback_fn.h"

# include <chrono>
# include <mutex>
// These are needed because azure does a bad time conversion
# include <azure/core/datetime.hpp>
# include <azure/core/credentials/credentials.hpp>
# include <exception>
# include <iomanip>
# include <iostream>
# include <memory>
# include <sstream>
# include <mutex>

# ifdef LINK_AZURE_LIBS
# include <azure/core/credentials/credentials.hpp>
# endif
// These are needed because azure does a bad time conversion
# include <time.h>

# include "err_constants.h"
# include "trace_logger.h"
# include <azure/core/datetime.hpp>

namespace reinforcement_learning
{
namespace
{
/**
* @brief Get the GMT offset from the local time.
*/
inline std::chrono::system_clock::duration get_gmt_offset()
{
auto get_time_point = [](std::tm& tm)
{
// set the tm_isdst field to -1 to let mktime determine if DST is in effect
tm.tm_isdst = -1;
return std::chrono::system_clock::from_time_t(std::mktime(&tm));
};
std::time_t now = std::time(nullptr);
std::tm local_tm{};
localtime_s(&local_tm, &now);
std::tm gmt_tm{};
gmtime_s(&gmt_tm, &now);
return get_time_point(local_tm) - get_time_point(gmt_tm);
}
} // namespace

/**
* @brief A template class that provides Azure OAuth credentials.
*
* This class is a template that requires a type T which must be a subclass of
* Azure::Core::Credentials::TokenCredential. It implements the i_oauth_credentials_provider
* interface to provide OAuth tokens for Azure services.
*
* @tparam T The type of the Azure credential, must be a subclass of Azure::Core::Credentials::TokenCredential.
*/
template <typename T>
class azure_credentials_provider
class azure_credentials_provider : public i_oauth_credentials_provider
{
static_assert(std::is_base_of<Azure::Core::Credentials::TokenCredential, T>::value,
"T must be a subclass of Azure::Core::Credentials::TokenCredential");

public:
/**
* @brief Constructs an azure_credentials_provider with the given arguments.
*
* @tparam Args The types of the arguments to forward to the constructor of T.
* @param args The arguments to forward to the constructor of T.
*/
template <typename... Args>
azure_credentials_provider(Args&&... args) : _creds(std::make_unique<T>(std::forward<Args>(args)...))
{
_gmt_offset = get_gmt_offset();
}

int get_credentials(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out, i_trace* trace)
/**
* @brief Default destructor.
*/
~azure_credentials_provider() override = default;

/**
* @brief Retrieves an OAuth token for the given scopes.
*
* @param scopes The scopes for which the token is requested.
* @param token_out The output parameter where the token will be stored.
* @param expiry_out The output parameter where the token expiry time will be stored.
* @param trace The trace object for logging.
* @return int The error code indicating success or failure.
*/
int get_token(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out, i_trace* trace) override
{
using namespace Azure::Core;
using namespace Azure::Core::Credentials;
Expand All @@ -48,17 +98,7 @@ class azure_credentials_provider
}
TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token");
token_out = auth.Token;

// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
// expiry_out = static_cast<std::chrono::system_clock::time_point>(auth.ExpiresOn);
std::string dt_string = auth.ExpiresOn.ToString();
std::tm tm = {};
std::istringstream ss(dt_string);
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm));
expiry_out = get_expiry_time(auth);
}
catch (AuthenticationException& e)
{
Expand All @@ -78,10 +118,28 @@ class azure_credentials_provider
return error_code::success;
}

private:
/**
* @brief Gets the adjusted expiry time of the given access token.
*
* @param access_token The access token whose expiry time is to be calculated.
* @return std::chrono::system_clock::time_point The calculated expiry time.
* @remarks This function is needed because Azure library returns local time
* instead of GMT time.
*/
std::chrono::system_clock::time_point get_expiry_time(const Azure::Core::Credentials::AccessToken& access_token)
{
// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
return static_cast<std::chrono::system_clock::time_point>(access_token.ExpiresOn) - _gmt_offset;
}

private:
std::unique_ptr<T> _creds;
mutable std::mutex _creds_mtx;
std::chrono::system_clock::duration _gmt_offset;
};
} // namespace reinforcement_learning

#endif
#endif // LINK_AZURE_LIBS
12 changes: 12 additions & 0 deletions include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ const char* const PROTOCOL_VERSION = "protocol.version";
const char* const HTTP_API_KEY = "http.api.key";
const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name";
const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type";
const char* const AZURE_OAUTH_CREDENTIAL_TYPE = "azure.oauth.credential.type";
const char* const AZURE_OAUTH_CREDENTIAL_CLIENTID = "azure.oauth.credential.clientid";
const char* const AZURE_OAUTH_CREDENTIAL_TENANTID = "azure.oauth.credential.tenantid";
const char* const AZURE_OAUTH_CREDENTIAL_CLIENTSECRET = "azure.oauth.credential.clientsecret";
const char* const AUDIT_ENABLED = "audit.enabled";
const char* const AUDIT_OUTPUT_PATH = "audit.output.path";

Expand Down Expand Up @@ -120,6 +124,7 @@ const char* const NO_MODEL_DATA = "NO_MODEL_DATA";
const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA";
const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA";
const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH";
const char* const HTTP_MODEL_DATA_OAUTH_AZ = "HTTP_MODEL_DATA_OAUTH_AZ";
const char* const VW = "VW";
const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF";
const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER";
Expand All @@ -132,8 +137,11 @@ const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER";
const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER";
const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER";
const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH";
const char* const EPISODE_HTTP_API_SENDER_OAUTH_AZ = "EPISODE_HTTP_API_SENDER_OAUTH_AZ";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH_AZ = "OBSERVATION_HTTP_API_SENDER_OAUTH_AZ";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH_AZ = "INTERACTION_HTTP_API_SENDER_OAUTH_AZ";
const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER";
const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER";
const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER";
Expand All @@ -146,6 +154,10 @@ const char* const CONTENT_ENCODING_DEDUP = "DEDUP";
const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key";
const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer";
const char* const TRACE_LOG_LEVEL_DEFAULT = "info";
const char* const AZURE_OAUTH_CREDENTIALS_DEFAULT = "DEFAULT_CREDENTIAL";
const char* const AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY = "MANAGED_IDENTITY";
const char* const AZURE_OAUTH_CREDENTIALS_AZURECLI = "AZURECLI";
const char* const AZURE_OAUTH_CREDENTIALS_CLIENTSECRET = "CLIENT_SECRET";

const char* const QUEUE_MODE_DROP = "DROP";
const char* const QUEUE_MODE_BLOCK = "BLOCK";
Expand Down
Loading
Loading