Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/azure oauth #604

Merged
merged 26 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a04d24f
Initial commit for Azure Oauth work
peterychang Jan 29, 2024
a93981f
Fixing blob storage requests
peterychang Jan 30, 2024
ebba307
Merge branch 'master' into dev/azure_oauth
peterychang Feb 2, 2024
0a9f167
remove azure dependencies from mainline cmake files
peterychang Feb 2, 2024
e5ccf4f
Add AzureVcpkg to modules
peterychang Feb 2, 2024
b713eda
Fix keytype conversions
peterychang Feb 2, 2024
729160b
fix bad string conversion
peterychang Feb 2, 2024
9bc06eb
Update callback signature. Conditionally compile azure code
peterychang Feb 2, 2024
9d3ddcd
Add cmake option for azure libs, fix compile issues
peterychang Feb 2, 2024
7d19e78
separate azure dependencies from default
peterychang Feb 2, 2024
246697e
fix conditional compile for azure libs
peterychang Feb 5, 2024
1a9212b
rl_sim tenant id as parameter
peterychang Feb 6, 2024
980f523
testing workflow fixes
peterychang Feb 6, 2024
bbf8cc6
remove rapidjson required version for now
peterychang Feb 6, 2024
9235700
update workflows to fetch full vcpkg submodule
peterychang Feb 6, 2024
b59c0fb
run lint
peterychang Feb 6, 2024
5b5de4d
fixing more workflows
peterychang Feb 6, 2024
7c59f60
fixing clang tidy issues
peterychang Feb 6, 2024
81ff98c
more clang tidy fixes
peterychang Feb 6, 2024
f19c3e8
lint
peterychang Feb 6, 2024
b71e34d
fixing compile issues
peterychang Feb 6, 2024
b45d6a9
clang tidy
peterychang Feb 6, 2024
c70dac0
fix github workflow
peterychang Feb 6, 2024
cb7fb61
Remove unnecessary cmake files/commands
peterychang Feb 7, 2024
ec2f546
review comments
peterychang Feb 8, 2024
44260dc
Merge branch 'master' into dev/azure_oauth
peterychang Feb 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions cmake/Modules/AzureVcpkg.cmake
peterychang marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# We need to know an absolute path to our repo root to do things like referencing ./LICENSE.txt file.
set(AZ_ROOT_DIR "${CMAKE_CURRENT_LIST_DIR}/..")

macro(az_vcpkg_integrate)
message("Vcpkg integrate step.")

# AUTO CMAKE_TOOLCHAIN_FILE:
# User can call `cmake -DCMAKE_TOOLCHAIN_FILE="path_to_the_toolchain"` as the most specific scenario.
# As the last alternative (default case), Azure SDK will automatically clone VCPKG folder and set toolchain from there.
if(NOT DEFINED CMAKE_TOOLCHAIN_FILE)
message("CMAKE_TOOLCHAIN_FILE is not defined. Define it for the user.")
# Set AZURE_SDK_DISABLE_AUTO_VCPKG env var to avoid Azure SDK from cloning and setting VCPKG automatically
# This option delegate package's dependencies installation to user.
if(NOT DEFINED ENV{AZURE_SDK_DISABLE_AUTO_VCPKG})
message("AZURE_SDK_DISABLE_AUTO_VCPKG is not defined. Fetch a local copy of vcpkg.")
# GET VCPKG FROM SOURCE
# User can set env var AZURE_SDK_VCPKG_COMMIT to pick the VCPKG commit to fetch
set(VCPKG_COMMIT_STRING 43cf47eccfbe27006cf9534a5db809798f8c37fe) # default SDK tested commit
if(DEFINED ENV{AZURE_SDK_VCPKG_COMMIT})
message("AZURE_SDK_VCPKG_COMMIT is defined. Using that instead of the default.")
set(VCPKG_COMMIT_STRING "$ENV{AZURE_SDK_VCPKG_COMMIT}") # default SDK tested commit
endif()
message("Vcpkg commit string used: ${VCPKG_COMMIT_STRING}")
include(FetchContent)
FetchContent_Declare(
vcpkg
GIT_REPOSITORY https://github.com/microsoft/vcpkg.git
GIT_TAG ${VCPKG_COMMIT_STRING}
)
FetchContent_GetProperties(vcpkg)
# make sure to pull vcpkg only once.
if(NOT vcpkg_POPULATED)
FetchContent_Populate(vcpkg)
endif()
# use the vcpkg source path
set(CMAKE_TOOLCHAIN_FILE "${vcpkg_SOURCE_DIR}/scripts/buildsystems/vcpkg.cmake" CACHE STRING "")
endif()
endif()

# enable triplet customization
if(DEFINED ENV{VCPKG_DEFAULT_TRIPLET} AND NOT DEFINED VCPKG_TARGET_TRIPLET)
set(VCPKG_TARGET_TRIPLET "$ENV{VCPKG_DEFAULT_TRIPLET}" CACHE STRING "")
endif()
message("Vcpkg integrate step - DONE.")
endmacro()

macro(az_vcpkg_portfile_prep targetName fileName contentToRemove)
# with sdk/<lib>/vcpkg/<fileName>
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/vcpkg/${fileName}" fileContents)

# Windows -> Unix line endings
string(FIND fileContents "\r\n" crLfPos)

if (crLfPos GREATER -1)
string(REPLACE "\r\n" "\n" fileContents ${fileContents})
endif()

# remove comment header
string(REPLACE "${contentToRemove}" "" fileContents ${fileContents})

# undo Windows -> Unix line endings (if applicable)
if (crLfPos GREATER -1)
string(REPLACE "\n" "\r\n" fileContents ${fileContents})
endif()
unset(crLfPos)

# output to an intermediate location
file (WRITE "${CMAKE_BINARY_DIR}/vcpkg_prep/${targetName}/${fileName}" ${fileContents})
unset(fileContents)

# Produce the files to help with the vcpkg release.
# Go to the /out/build/<cfg>/vcpkg directory, and copy (merge) "ports" folder to the vcpkg repo.
# Then, update the portfile.cmake file SHA512 from "1" to the actual hash (a good way to do it is to uninstall a package,
# clean vcpkg/downloads, vcpkg/buildtrees, run "vcpkg install <pkg>", and get the SHA from the error message).
configure_file(
"${CMAKE_BINARY_DIR}/vcpkg_prep/${targetName}/${fileName}"
"${CMAKE_BINARY_DIR}/vcpkg/ports/${targetName}-cpp/${fileName}"
@ONLY
)
endmacro()

macro(az_vcpkg_export targetName macroNamePart dllImportExportHeaderPath)
foreach(vcpkgFile "vcpkg.json" "portfile.cmake")
az_vcpkg_portfile_prep(
"${targetName}"
"${vcpkgFile}"
"# Copyright (c) Microsoft Corporation.\n# Licensed under the MIT License.\n\n"
)
endforeach()

# Standard names for folders such as "bin", "lib", "include". We could hardcode, but some other libs use it too (curl).
include(GNUInstallDirs)

# When installing, copy our "inc" directory (headers) to "include" directory at the install location.
install(DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/inc/azure/" DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/azure")

# Copy license as "copyright" (vcpkg dictates naming and location).
install(FILES "${AZ_ROOT_DIR}/LICENSE.txt" DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" RENAME "copyright")

# Indicate where to install targets. Mirrors what other ports do.
install(
TARGETS "${targetName}"
EXPORT "${targetName}-cppTargets"
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # DLLs (if produced by build) go to "/bin"
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} # static .lib files
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} # .lib files for DLL build
INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} # headers
)

# If building a Windows DLL, patch the dll_import_export.hpp
if(WIN32 AND BUILD_SHARED_LIBS)
add_compile_definitions(AZ_${macroNamePart}_BEING_BUILT)
target_compile_definitions(${targetName} PUBLIC AZ_${macroNamePart}_DLL)

set(AZ_${macroNamePart}_DLL_INSTALLED_AS_PACKAGE "*/ + 1 /*")
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/inc/${dllImportExportHeaderPath}"
"${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderPath}"
@ONLY
)
unset(AZ_${macroNamePart}_DLL_INSTALLED_AS_PACKAGE)

get_filename_component(dllImportExportHeaderDir ${dllImportExportHeaderPath} DIRECTORY)
install(
FILES "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderPath}"
DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/${dllImportExportHeaderDir}"
)
unset(dllImportExportHeaderDir)
endif()

# Export the targets file itself.
install(
EXPORT "${targetName}-cppTargets"
DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp"
NAMESPACE Azure:: # Not the C++ namespace, but a namespace in terms of cmake.
FILE "${targetName}-cppTargets.cmake"
)

# configure_package_config_file(), write_basic_package_version_file()
include(CMakePackageConfigHelpers)

# Produce package config file.
configure_package_config_file(
"${CMAKE_CURRENT_SOURCE_DIR}/vcpkg/Config.cmake.in"
"${targetName}-cppConfig.cmake"
INSTALL_DESTINATION "${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp"
PATH_VARS
CMAKE_INSTALL_LIBDIR)

# Produce version file.
write_basic_package_version_file(
"${targetName}-cppConfigVersion.cmake"
VERSION ${AZ_LIBRARY_VERSION} # the version that we extracted from package_version.hpp
COMPATIBILITY SameMajorVersion
)

# Install package config and version files.
install(
FILES
"${CMAKE_CURRENT_BINARY_DIR}/${targetName}-cppConfig.cmake"
"${CMAKE_CURRENT_BINARY_DIR}/${targetName}-cppConfigVersion.cmake"
DESTINATION
"${CMAKE_INSTALL_DATAROOTDIR}/${targetName}-cpp" # to shares/<our_pkg>
)

# Export all the installs above as package.
export(PACKAGE "${targetName}-cpp")
endmacro()
23 changes: 22 additions & 1 deletion examples/rl_sim_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
add_executable(rl_sim_cpp.out
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_LIST_DIR}/../../cmake/Modules/")

if(RL_CXX_STANDARD GREATER_EQUAL 14 AND vw_USE_AZURE_FACTORIES)
include(AzureVcpkg)
az_vcpkg_integrate()
endif()

set(RL_SIM_SOURCES
main.cc
person.cc
robot_joint.cc
rl_sim.cc
)
if(RL_CXX_STANDARD GREATER_EQUAL 14 AND vw_USE_AZURE_FACTORIES)
list(APPEND RL_SIM_SOURCES
azure_credentials.cc
)
endif()

add_executable(rl_sim_cpp.out
${RL_SIM_SOURCES}
)

target_link_libraries(rl_sim_cpp.out PRIVATE Boost::program_options rlclientlib)

if(RL_CXX_STANDARD GREATER_EQUAL 14 AND vw_USE_AZURE_FACTORIES)
find_package(azure-identity-cpp CONFIG REQUIRED)
target_link_libraries(rl_sim_cpp.out PRIVATE Azure::azure-identity)
endif()
63 changes: 63 additions & 0 deletions examples/rl_sim_cpp/azure_credentials.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "azure_credentials.h"
#include "err_constants.h"
#include "future_compat.h"

#include <azure/core/datetime.hpp>
#include <chrono>
// These are needed because azure does a bad time conversion
#include <iomanip>
#include <sstream>

#include <exception>
#include <iostream>

using namespace reinforcement_learning;

Azure::Identity::AzureCliCredentialOptions AzureCredentials::create_options()
{
Azure::Identity::AzureCliCredentialOptions options;
options.TenantId = _tenant_id;
options.AdditionallyAllowedTenants.push_back("*");
return options;
}

AzureCredentials::AzureCredentials()
: _creds(create_options())
{}

int AzureCredentials::get_credentials(std::string& token_out, std::chrono::system_clock::time_point& expiry_out,
const std::vector<std::string>& scopes)
{
#ifdef HAS_STD14
Azure::Core::Credentials::TokenRequestContext request_context;
request_context.Scopes = scopes;
// TODO: needed?
request_context.TenantId = _tenant_id;
Azure::Core::Context context;
try {
auto auth = _creds.GetToken(request_context, context);
token_out = auth.Token;

// Casting from an azure DateTime object to a time_point does the calculation
// incorrectly. The expiration is returned as a local time, but the library
// assumes that it is GMT, and converts the value incorrectly.
// See: https://github.com/Azure/azure-sdk-for-cpp/issues/5075
//expiry_out = static_cast<std::chrono::system_clock::time_point>(auth.ExpiresOn);
std::string dt_string = auth.ExpiresOn.ToString();
std::tm tm = {};
std::istringstream ss(dt_string);
ss >> std::get_time(&tm, "%Y-%m-%dT%H:%M:%SZ");
expiry_out = std::chrono::system_clock::from_time_t(std::mktime(&tm));

}
catch(std::exception& e){
std::cout << "Error getting auth token: " << e.what();
return error_code::external_error;
}
catch(std::exception& e){
std::cout << "Unknown error while getting auth token";
return error_code::external_error;
}
#endif
return error_code::success;
}
29 changes: 29 additions & 0 deletions examples/rl_sim_cpp/azure_credentials.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once
#include <memory>

#include "api_status.h"
#include "configuration.h"

#include "future_compat.h"

#include <azure/identity/default_azure_credential.hpp>
#include <azure/identity/azure_cli_credential.hpp>

#include <chrono>
#include <string>

class AzureCredentials
peterychang marked this conversation as resolved.
Show resolved Hide resolved
{
public:
AzureCredentials();
int get_credentials(std::string& token_out, std::chrono::system_clock::time_point& expiry_out,
peterychang marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<std::string>& scopes);
private:
#ifdef HAS_STD14
Azure::Identity::AzureCliCredentialOptions create_options();

//Azure::Identity::DefaultAzureCredential _creds;
Azure::Identity::AzureCliCredential _creds;
std::string _tenant_id = "<tenant_id>";
#endif
};
3 changes: 2 additions & 1 deletion examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ po::variables_map process_cmd_line(const int argc, char** argv)
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value<bool>()->default_value(true),
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats");
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats")(
"azure_oauth_factories", po::value<bool>()->default_value(false), "Use oauth for azure factores. Default false");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
14 changes: 13 additions & 1 deletion examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
#include "simulation_stats.h"
#include "trace_logger.h"

#include "future_compat.h"

#include <boost/uuid/random_generator.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <chrono>
#include <cmath>
#include <functional>
#include <thread>

using namespace std;
Expand Down Expand Up @@ -487,6 +490,15 @@ int rl_sim::init_rl()
wrap_sender_generate_for_throughput_sender(reinforcement_learning::value::EPISODE_HTTP_API_SENDER));
sender_factory = &factory;
}
// probably incompatible with the throughput option?
else if (_options["azure_oauth_factories"].as<bool>())
{
// Note: This requires C++14 or better
using namespace std::placeholders;
reinforcement_learning::oauth_callback_t callback =
std::bind(&AzureCredentials::get_credentials, &_creds, _1, _2, _3);
reinforcement_learning::register_default_factories_callback(callback);
}

// Initialize the API
_rl = std::unique_ptr<r::live_model>(new r::live_model(config, _on_error, this,
Expand Down Expand Up @@ -699,4 +711,4 @@ std::string get_dist_str(const reinforcement_learning::decision_response& respon
}
ret += ")";
return ret;
}
}
2 changes: 2 additions & 0 deletions examples/rl_sim_cpp/rl_sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "live_model.h"
#include "person.h"
#include "robot_joint.h"
#include "azure_credentials.h"

#include <boost/program_options.hpp>

Expand Down Expand Up @@ -177,4 +178,5 @@ class rl_sim
int64_t _delay = 2000;
bool _quiet = false;
bool _random_ids = true;
AzureCredentials _creds;
};
6 changes: 6 additions & 0 deletions include/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const char* const LEARNING_MODE = "rank.learning.mode";
const char* const PROTOCOL_VERSION = "protocol.version";
const char* const HTTP_API_KEY = "http.api.key";
const char* const HTTP_API_HEADER_KEY_NAME = "http.api.header.key.name";
const char* const HTTP_API_OAUTH_TOKEN_TYPE = "http.api.oauth.token.type";
const char* const AUDIT_ENABLED = "audit.enabled";
const char* const AUDIT_OUTPUT_PATH = "audit.output.path";

Expand Down Expand Up @@ -118,6 +119,7 @@ const char* const AZURE_STORAGE_BLOB = "AZURE_STORAGE_BLOB";
const char* const NO_MODEL_DATA = "NO_MODEL_DATA";
const char* const HTTP_MODEL_DATA = "HTTP_MODEL_DATA";
const char* const FILE_MODEL_DATA = "FILE_MODEL_DATA";
const char* const HTTP_MODEL_DATA_OAUTH = "HTTP_MODEL_DATA_OAUTH";
const char* const VW = "VW";
const char* const PASSTHROUGH_PDF_MODEL = "PASSTHROUGH_PDF";
const char* const EPISODE_EH_SENDER = "EPISODE_EH_SENDER";
Expand All @@ -129,6 +131,9 @@ const char* const INTERACTION_FILE_SENDER = "INTERACTION_FILE_SENDER";
const char* const EPISODE_HTTP_API_SENDER = "EPISODE_HTTP_API_SENDER";
const char* const OBSERVATION_HTTP_API_SENDER = "OBSERVATION_HTTP_API_SENDER";
const char* const INTERACTION_HTTP_API_SENDER = "INTERACTION_HTTP_API_SENDER";
const char* const EPISODE_HTTP_API_SENDER_OAUTH = "EPISODE_HTTP_API_SENDER_OAUTH";
const char* const OBSERVATION_HTTP_API_SENDER_OAUTH = "OBSERVATION_HTTP_API_SENDER_OAUTH";
const char* const INTERACTION_HTTP_API_SENDER_OAUTH = "INTERACTION_HTTP_API_SENDER_OAUTH";
const char* const NULL_TRACE_LOGGER = "NULL_TRACE_LOGGER";
const char* const CONSOLE_TRACE_LOGGER = "CONSOLE_TRACE_LOGGER";
const char* const NULL_TIME_PROVIDER = "NULL_TIME_PROVIDER";
Expand All @@ -139,6 +144,7 @@ const char* const LEARNING_MODE_LOGGINGONLY = "LOGGINGONLY";
const char* const CONTENT_ENCODING_IDENTITY = "IDENTITY";
const char* const CONTENT_ENCODING_DEDUP = "DEDUP";
const char* const HTTP_API_DEFAULT_HEADER_KEY_NAME = "Ocp-Apim-Subscription-Key";
const char* const HTTP_API_DEFAULT_OAUTH_TOKEN_TYPE = "Bearer";
const char* const TRACE_LOG_LEVEL_DEFAULT = "info";

const char* const QUEUE_MODE_DROP = "DROP";
Expand Down
Loading
Loading