Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rlclientlib and rl.net to support azure credential hooks #609

Merged
merged 19 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
63216c9
added azure_credentials_provider.h to the public headers
v-jameslongo Aug 8, 2024
484c5e5
added memory and azure include to the azure_credentials_provider
v-jameslongo Aug 9, 2024
cd6e87a
added trace logger to the azure credential providers
v-jameslongo Aug 9, 2024
ebfcfe7
removed trace from lock scope
v-jameslongo Aug 9, 2024
14303e8
added set(CMAKE_DEBUG_POSTFIX "") to ensure VS builds have the same l…
v-jameslongo Aug 9, 2024
fc2fc17
updated formatting (according to tidy)
v-jameslongo Aug 9, 2024
060ce8a
moved find dotnet-t4 to FindDotnet.cmake and check it for windows onl…
v-jameslongo Aug 12, 2024
ac470e6
the x64 macos image was changed to macos-latest-large. see (https://g…
v-jameslongo Aug 12, 2024
56ff878
for now, moving to build runner macos-13 supporting x64
v-jameslongo Aug 12, 2024
f05a552
updates to provide default azure credential implementations via the f…
v-jameslongo Aug 14, 2024
65ce2a7
fixed formatting
v-jameslongo Aug 14, 2024
70894c8
added missing call to azure_cred_provider_factory_t dtor in factory_…
v-jameslongo Aug 21, 2024
ae91b1c
update the job name for macos so that we don't need to change the sta…
v-jameslongo Aug 22, 2024
e03dd80
fix using ternary operator expression
v-jameslongo Aug 22, 2024
97c488c
updated the vcpkg build to align the macos package names with the sta…
v-jameslongo Aug 22, 2024
f24f886
the job name should 'macos-latest'
v-jameslongo Aug 22, 2024
6f3ff3d
updated another macos-13 os tag to macos-latest
v-jameslongo Aug 22, 2024
ae768f4
changed macos-13 exact match to startsWith
v-jameslongo Aug 22, 2024
bbd92c1
corrected startsWith 'macos-13' to 'macos'
v-jameslongo Aug 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions bindings/cs/rl.net.native/rl.net.azure_factories.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,38 @@ rl_net_native::azure_factory_oauth_callback_t g_oauth_callback = nullptr;
rl_net_native::azure_factory_oauth_callback_complete_t g_oauth_callback_complete = nullptr;

static int azure_factory_oauth_callback(const std::vector<std::string>& scopes,
std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry)
std::string& oauth_token, std::chrono::system_clock::time_point& token_expiry,
reinforcement_learning::i_trace *trace)
{
if (g_oauth_callback == nullptr) {
return -1;
if (g_oauth_callback == nullptr)
{
return reinforcement_learning::error_code::invalid_argument;
}
// create a null terminated array of scope string pointers
// these are pointers are readonly and owned by the caller
std::vector<const char*> native_scopes;
native_scopes.reserve(scopes.size() + 1);
// create a null terminated array of scope string pointers.
// these are pointers are readonly and owned by the caller.
// since we create the vector of n elements, we can guarantee
// the last element is null.
std::vector<const char*> native_scopes(scopes.size() + 1);
for (int i = 0; i < scopes.size(); ++i)
{
native_scopes.push_back(scopes[i].c_str());
native_scopes[i] = scopes[i].c_str();
}
native_scopes.push_back(nullptr);
// we expect to get a pointer to a null terminated string
// we expect to get a pointer to a null terminated string.
// it's allocated on the managed heap, so we can't free it here
// instead, we will pass it back to the managed code after we copy it
// instead, we will pass it back to the managed code after we copy it.
char *oauth_token_ptr = nullptr;
std::int64_t expiryUnixTime = 0;
auto ret = g_oauth_callback(native_scopes.data(), &oauth_token_ptr, &expiryUnixTime);
if (ret == reinforcement_learning::error_code::success) {
if (ret == reinforcement_learning::error_code::success)
{
TRACE_DEBUG(trace, "rl_net_native::azure_factory_oauth_callback: successfully retrieved token");
oauth_token = oauth_token_ptr;
token_expiry = std::chrono::system_clock::from_time_t(expiryUnixTime);
}
else
{
TRACE_ERROR(trace, "rl_net_native::azure_factory_oauth_callback: failed to retrieve token");
}
g_oauth_callback_complete(oauth_token_ptr, reinforcement_learning::error_code::success);
return ret;
}
Expand All @@ -38,10 +46,11 @@ static int azure_factory_oauth_callback(const std::vector<std::string>& scopes,
API void RegisterDefaultFactoriesCallback(rl_net_native::azure_factory_oauth_callback_t callback,
rl_net_native::azure_factory_oauth_callback_complete_t completion)
{
if (rl_net_native::g_oauth_callback == nullptr) {
if (rl_net_native::g_oauth_callback == nullptr)
{
rl_net_native::g_oauth_callback = callback;
reinforcement_learning::register_default_factories_callback(
reinforcement_learning::oauth_callback_t{rl_net_native::azure_factory_oauth_callback});
reinforcement_learning::oauth_callback_t { rl_net_native::azure_factory_oauth_callback });
rl_net_native::g_oauth_callback_complete = completion;
}
}
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.native/rl.net.azure_factories.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ typedef int (*azure_factory_oauth_callback_t)(const char**, char**, std::int64_t

// Callback function to be called by the C# code to signal the completion of the OAuth token request
// this provides the C# code with the opportunity to free the memory allocated for the token
// tokenStringToFree - pointer to the token string than needs to be freed
// errorCode - for future use
typedef void (*azure_factory_oauth_callback_complete_t)(char*, int);
} // namespace rl_net_native

Expand Down
2 changes: 1 addition & 1 deletion examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ int rl_sim::init_rl()
// Note: This requires C++14 or better
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3);
std::bind(&azure_credentials_provider_t::get_credentials, &_creds, _1, _2, _3, _4);
reinforcement_learning::register_default_factories_callback(callback);
#endif
}
Expand Down
72 changes: 38 additions & 34 deletions include/azure_credentials_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,6 @@

#ifdef LINK_AZURE_LIBS

#include "api_status.h"
#include "configuration.h"
#include "future_compat.h"

#include "err_constants.h"
#include "future_compat.h"

#include <azure/core/datetime.hpp>
#include <chrono>
#include <mutex>
// These are needed because azure does a bad time conversion
Expand All @@ -19,34 +11,44 @@
#include <sstream>
#include <memory>

# ifdef LINK_AZURE_LIBS
# include <azure/core/credentials/credentials.hpp>
# endif
#include <azure/core/datetime.hpp>

#ifdef LINK_AZURE_LIBS
#include <azure/core/credentials/credentials.hpp>
#endif

#include "err_constants.h"
#include "trace_logger.h"

namespace reinforcement_learning
{

template<typename T>
class azure_credentials_provider
{
public:
template<typename... Args>
azure_credentials_provider(Args&&... args) :
_creds(std::make_unique<T>(std::forward<Args>(args)...)) {}
_creds(std::make_unique<T>(std::forward<Args>(args)...))
{
}

int get_credentials(const std::vector<std::string>& scopes, std::string& token_out,
std::chrono::system_clock::time_point& expiry_out)
std::chrono::system_clock::time_point& expiry_out, i_trace* trace)
{
using namespace Azure::Core;
using namespace Azure::Core::Credentials;
try
{
Azure::Core::Credentials::TokenRequestContext request_context;
request_context.Scopes = scopes;
TokenRequestContext request_context;
Context context;

Azure::Core::Context context;
std::cout << "fetching token for " << scopes[0] << std::endl;

std::lock_guard<std::mutex> lock(_creds_mtx);
auto auth = _creds->GetToken(request_context, context);
request_context.Scopes = scopes;
AccessToken auth;
{
std::lock_guard<std::mutex> lock(_creds_mtx);
auth = _creds->GetToken(request_context, context);
TRACE_DEBUG(trace, "azure_credentials_provider: successfully retrieved token");
v-jameslongo marked this conversation as resolved.
Show resolved Hide resolved
}
token_out = auth.Token;

// Casting from an azure DateTime object to a time_point does the calculation
Expand All @@ -59,27 +61,29 @@ class azure_credentials_provider
std::istringstream ss(dt_string);
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm));

return 0;
}
catch (std::exception& e)
{
std::cout << "Error getting auth token: " << e.what();
return error_code::external_error;
}
catch (...)
{
std::cout << "Unknown error while getting auth token";
return error_code::external_error;
}
catch (AuthenticationException& e)
{
TRACE_ERROR(trace, e.what());
return error_code::http_oauth_authentication_error;
}
catch (std::exception& e)
{
TRACE_ERROR(trace, e.what());
return error_code::http_oauth_unexpected_error;
}
catch (...)
{
TRACE_ERROR(trace, "azure_credentials_provider: an unexpected unknown error occurred");
return error_code::http_oauth_unexpected_error;
}
return error_code::success;
}

private:
std::unique_ptr<T> _creds;
mutable std::mutex _creds_mtx;
};

} // namespace reinforcement_learning

#endif
2 changes: 2 additions & 0 deletions include/errors_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,6 @@ ERROR_CODE_DEFINITION(49, baseline_actions_not_defined, "Baseline Actions must b
ERROR_CODE_DEFINITION(50, http_api_key_not_provided, "Http api key must be provided")
ERROR_CODE_DEFINITION(51, http_model_uri_not_provided, "Model Blob URI parameter was not passed in via configuration")
ERROR_CODE_DEFINITION(52, static_model_load_error, "Static model passed in C# layer is not loading properly")
ERROR_CODE_DEFINITION(53, http_oauth_authentication_error, "http request failed to authenticate")
ERROR_CODE_DEFINITION(54, http_oauth_unexpected_error, "http request failed with an unexpected error while retrieving a token")
//! [Error Definitions]
3 changes: 2 additions & 1 deletion include/oauth_callback_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
#include <functional>
#include <string>
#include <vector>
#include "trace_logger.h"

namespace reinforcement_learning
{
using oauth_callback_t =
std::function<int(const std::vector<std::string>&, std::string&, std::chrono::system_clock::time_point&)>;
std::function<int(const std::vector<std::string>&, std::string&, std::chrono::system_clock::time_point&, i_trace* trace)>;
}
2 changes: 1 addition & 1 deletion rlclientlib/utility/api_header_token.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class api_header_token_callback
{
using namespace std::chrono;
system_clock::time_point tp;
RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry));
RETURN_IF_FAIL(_token_callback(_scopes, _bearer_token, _token_expiry, trace));

if (_bearer_token.empty())
{
Expand Down
Loading