Skip to content

Commit

Permalink
updates to provide default azure credential implementations via the f…
Browse files Browse the repository at this point in the history
…actory resolver

cmake changes to link with azure identity when RL_LINK_AZURE_LIBS is defined
  • Loading branch information
v-jameslongo committed Aug 14, 2024
1 parent 56ff878 commit f05a552
Show file tree
Hide file tree
Showing 21 changed files with 492 additions and 87 deletions.
5 changes: 0 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ if(POLICY CMP0091)
cmake_policy(SET CMP0091 NEW)
endif()

# ensure all of the build tools generate the same output on all platforms
# note: this change was made since building with Ninja does not add suffixes
# but, using the VS generator does.
set(CMAKE_DEBUG_POSTFIX "")

if(WIN32)
# Due to needing to configure the CMAKE platform, this needs to be included before the
# top-level project() declaration.
Expand Down
7 changes: 7 additions & 0 deletions bindings/cs/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/Modules/")
include(FindDotnet)

# note: this change was made since building with Ninja does not add suffixes
# but, using the VS generator does. rl.net uses dllimport to load rlnetnative.
# This is a workaround to make sure the correct dll is used.
if (WIN32 AND CMAKE_GENERATOR MATCHES "Visual Studio")
set(CMAKE_DEBUG_POSTFIX "")
endif()

add_subdirectory(rl.net.native)
add_subdirectory(rl.net)
add_subdirectory(rl.net.cli)
Expand Down
6 changes: 6 additions & 0 deletions examples/basic_usage_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ add_executable(basic_usage_cpp.out
)

target_link_libraries(basic_usage_cpp.out PRIVATE rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(basic_usage_cpp.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(basic_usage_cpp.out PRIVATE Azure::azure-identity)
endif()
6 changes: 6 additions & 0 deletions examples/override_interface/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@ add_executable(override_interface.out
)

target_link_libraries(override_interface.out PRIVATE rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(override_interface.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(override_interface.out PRIVATE Azure::azure-identity)
endif()
3 changes: 1 addition & 2 deletions examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ po::variables_map process_cmd_line(const int argc, char** argv)
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value<bool>()->default_value(true),
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")(
"azure_oauth_factories", po::value<bool>()->default_value(false), "Use oauth for azure factores. Default false");
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
18 changes: 10 additions & 8 deletions examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,16 +490,18 @@ int rl_sim::init_rl()
sender_factory = &factory;
}
// probably incompatible with the throughput option?
else if (_options["azure_oauth_factories"].as<bool>())
{
#ifdef LINK_AZURE_LIBS
// Note: This requires C++14 or better
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
// Note: The azure_oauth_factories switch has been removed since registering the factories by default does not
// directly utilize the callback. The factory implementations need to be configured in the configuration file
// for them to function (see factory_resolver.h).
// Note: If USE_AZURE_FACTORIES is defined, the default factory implementations will be different (see
// factory_resolver.h).
// Note: This requires C++14 or better since the Azure libraries require C++14 or better.
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_token, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
#endif
}

// Initialize the API
_rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
Expand Down
6 changes: 6 additions & 0 deletions examples/test_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ add_executable(rl_test.out
target_include_directories(rl_test.out PRIVATE $<TARGET_PROPERTY:rlclientlib,INCLUDE_DIRECTORIES>)

target_link_libraries(rl_test.out PRIVATE Boost::program_options rlclientlib)

if(RL_LINK_AZURE_LIBS)
target_compile_definitions(rl_test.out PRIVATE LINK_AZURE_LIBS)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(rl_test.out PRIVATE Azure::azure-identity)
endif()
114 changes: 86 additions & 28 deletions include/azure_credentials_provider.h
Original file line number Diff line number Diff line change
@@ -1,37 +1,87 @@
#pragma once

#ifdef LINK_AZURE_LIBS
# include "err_constants.h"
# include "oauth_callback_fn.h"

# include <chrono>
# include <mutex>
// These are needed because azure does a bad time conversion
# include <azure/core/datetime.hpp>
# include <azure/core/credentials/credentials.hpp>
# include <exception>
# include <iomanip>
# include <iostream>
# include <memory>
# include <sstream>
# include <mutex>

# ifdef LINK_AZURE_LIBS
# include <azure/core/credentials/credentials.hpp>
# endif
// These are needed because azure does a bad time conversion
# include <time.h>

# include "err_constants.h"
# include "trace_logger.h"
# include <azure/core/datetime.hpp>

namespace reinforcement_learning
{
namespace
{
/**
* @brief Get the GMT offset from the local time.
*/
inline std::chrono::system_clock::duration get_gmt_offset()
{
auto get_time_point = [](std::tm& tm)
{
// set the tm_isdst field to -1 to let mktime determine if DST is in effect
tm.tm_isdst = -1;
return std::chrono::system_clock::from_time_t(std::mktime(&tm));
};
std::time_t now = std::time(nullptr);
std::tm local_tm{};
localtime_s(&local_tm, &now);
std::tm gmt_tm{};
gmtime_s(&gmt_tm, &now);
return get_time_point(local_tm) - get_time_point(gmt_tm);
}
} // namespace

/**
* @brief A template class that provides Azure OAuth credentials.
*
* This class is a template that requires a type T which must be a subclass of
* Azure::Core::Credentials::TokenCredential. It implements the i_oauth_credentials_provider
* interface to provide OAuth tokens for Azure services.
*
* @tparam T The type of the Azure credential, must be a subclass of Azure::Core::Credentials::TokenCredential.
*/
template <typename T>
class azure_credentials_provider
class azure_credentials_provider : public i_oauth_credentials_provider
{
static_assert(std::is_base_of<Azure::Core::Credentials::TokenCredential, T>::value,
"T must be a subclass of Azure::Core::Credentials::TokenCredential");

public:
/**
* @brief Constructs an azure_credentials_provider with the given arguments.
*
* @tparam Args The types of the arguments to forward to the constructor of T.
* @param args The arguments to forward to the constructor of T.
*/
template <typename... Args>
azure_credentials_provider(Args&&... args) : _creds(std::make_unique<T>(std::forward<Args>(args)...))
{
_gmt_offset = get_gmt_offset();
}

int get_credentials(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out, i_trace* trace)
/**
* @brief Default destructor.
*/
~azure_credentials_provider() override = default;

/**
* @brief Retrieves an OAuth token for the given scopes.
*
* @param scopes The scopes for which the token is requested.
* @param token_out The output parameter where the token will be stored.
* @param expiry_out The output parameter where the token expiry time will be stored.
* @param trace The trace object for logging.
* @return int The error code indicating success or failure.
*/
int get_token(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out, i_trace* trace) override
{
using namespace Azure::Core;
using namespace Azure::Core::Credentials;
Expand All @@ -48,17 +98,7 @@ class azure_credentials_provider
}
TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token");
token_out = auth.Token;

// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
// expiry_out = static_cast<std::chrono::system_clock::time_point>(auth.ExpiresOn);
std::string dt_string = auth.ExpiresOn.ToString();
std::tm tm = {};
std::istringstream ss(dt_string);
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm));
expiry_out = get_expiry_time(auth);
}
catch (AuthenticationException& e)
{
Expand All @@ -78,10 +118,28 @@ class azure_credentials_provider
return error_code::success;
}

private:
/**
* @brief Gets the adjusted expiry time of the given access token.
*
* @param access_token The access token whose expiry time is to be calculated.
* @return std::chrono::system_clock::time_point The calculated expiry time.
* @remarks This function is needed because Azure library returns local time
* instead of GMT time.
*/
std::chrono::system_clock::time_point get_expiry_time(const Azure::Core::Credentials::AccessToken& access_token)
{
// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
return static_cast<std::chrono::system_clock::time_point>(access_token.ExpiresOn) - _gmt_offset;
}

private:
std::unique_ptr<T> _creds;
mutable std::mutex _creds_mtx;
std::chrono::system_clock::duration _gmt_offset;
};
} // namespace reinforcement_learning

#endif
#endif // LINK_AZURE_LIBS
12 changes: 12 additions & 0 deletions include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ const char* const PROTOCOL_VERSION = "protocol.version";
const char* const HTTP_API_KEY = "http.api.key";
const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name";
const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type";
const char* const AZURE_OAUTH_CREDENTIAL_TYPE = "azure.oauth.credential.type";
const char* const AZURE_OAUTH_CREDENTIAL_CLIENTID = "azure.oauth.credential.clientid";
const char* const AZURE_OAUTH_CREDENTIAL_TENANTID = "azure.oauth.credential.tenantid";
const char* const AZURE_OAUTH_CREDENTIAL_CLIENTSECRET = "azure.oauth.credential.clientsecret";
const char* const AUDIT_ENABLED = "audit.enabled";
const char* const AUDIT_OUTPUT_PATH = "audit.output.path";

Expand Down Expand Up @@ -120,6 +124,7 @@ const char* const NO_MODEL_DATA = "NO_MODEL_DATA";
const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA";
const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA";
const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH";
const char* const HTTP_MODEL_DATA_OAUTH_AZ = "HTTP_MODEL_DATA_OAUTH_AZ";
const char* const VW = "VW";
const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF";
const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER";
Expand All @@ -132,8 +137,11 @@ const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER";
const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER";
const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER";
const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH";
const char* const EPISODE_HTTP_API_SENDER_OAUTH_AZ = "EPISODE_HTTP_API_SENDER_OAUTH_AZ";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH_AZ = "OBSERVATION_HTTP_API_SENDER_OAUTH_AZ";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH_AZ = "INTERACTION_HTTP_API_SENDER_OAUTH_AZ";
const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER";
const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER";
const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER";
Expand All @@ -146,6 +154,10 @@ const char* const CONTENT_ENCODING_DEDUP = "DEDUP";
const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key";
const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer";
const char* const TRACE_LOG_LEVEL_DEFAULT = "info";
const char* const AZURE_OAUTH_CREDENTIALS_DEFAULT = "DEFAULT_CREDENTIAL";
const char* const AZURE_OAUTH_CREDENTIALS_MANAGEDIDENTITY = "MANAGED_IDENTITY";
const char* const AZURE_OAUTH_CREDENTIALS_AZURECLI = "AZURECLI";
const char* const AZURE_OAUTH_CREDENTIALS_CLIENTSECRET = "CLIENT_SECRET";

const char* const QUEUE_MODE_DROP = "DROP";
const char* const QUEUE_MODE_BLOCK = "BLOCK";
Expand Down
Loading

0 comments on commit f05a552

Please sign in to comment.