Skip to content

Commit

Permalink
fix: Pull in latest VW flatbuffer parser API changes/fixes
Browse files Browse the repository at this point in the history
* Also implement a mechanism to manage collisions between RL and VW api_status.h
  • Loading branch information
lokitoth committed Feb 16, 2024
1 parent 27cd823 commit 77f73d5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 7 deletions.
1 change: 1 addition & 0 deletions rlclientlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ set(PROJECT_PRIVATE_HEADERS
vw_model/pdf_model.h
vw_model/safe_vw.h
vw_model/vw_model.h
vw_model/vw_api_status_interop.h
)

if(vw_USE_AZURE_FACTORIES)
Expand Down
28 changes: 22 additions & 6 deletions rlclientlib/vw_model/safe_vw.cc
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include "safe_vw.h"
#include "constants.h"
#include "vw_api_status_interop.h"
#include "api_status.h"


// VW headers
#include "vw/config/options.h"
#include "vw/core/debug_print.h"
Expand Down Expand Up @@ -188,7 +188,23 @@ namespace detail
template <bool audit>
static void parse_context(VW::workspace& w, string_view context, VW::example_factory_t ex_fac, VW::multi_ex& examples, VW::example_sink_f ex_sink)
{
VW::parsers::flatbuffer::read_span_flatbuffer(&w, reinterpret_cast<const uint8_t*>(context.data()), context.size(), ex_fac, examples, ex_sink);
VW::experimental::api_status vw_status;
if (VW::parsers::flatbuffer::read_span_flatbuffer(&w, reinterpret_cast<const uint8_t*>(context.data()), context.size(), ex_fac, examples, ex_sink, &vw_status) !=
VW::experimental::error_code::success)
{
std::stringstream sstream;
sstream << "Failed to parse flatbuffer: " << vw_status.get_error_msg();

// This is a bit unfortunate, but the APIs around safe_vw were built with
// the original VW error handling (THROW() macro) in mind, and catching the
// exception at the level of vw_model. Ideally we should do the work to
// define a consistent error handling model in VW, and use that here.
//
// In the short term, wrap the error in an exception and throw it up
// one level.
throw std::runtime_error(sstream.str().c_str());
}

}
};

Expand All @@ -197,7 +213,7 @@ namespace detail
{
examples.push_back(&ex_fac());
ensure_audit_buffer<audit>(w);

switch (input_format)
{
case input_serialization::vwjson:
Expand All @@ -223,7 +239,7 @@ void safe_vw::parse_context(string_view context, VW::multi_ex& examples)
_example_pool.emplace_back(ex);
}
};


if (_vw->output_config.audit)
{
Expand Down Expand Up @@ -286,8 +302,8 @@ void safe_vw::rank(string_view context, std::vector<int>& actions, std::vector<f

// clean up examples and push examples back into pool for re-use
auto scope_guard = VW::scope_exit([this, &examples] {
for (auto&& ex : examples)
{
for (auto&& ex : examples)
{
ex->pred.a_s.clear();
_example_pool.emplace_back(ex);
}
Expand Down
20 changes: 20 additions & 0 deletions rlclientlib/vw_model/vw_api_status_interop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "vw/core/api_status.h"
#include "vw/core/error_constants.h"

#ifdef RETURN_ERROR
#undef RETURN_ERROR
#endif

#ifdef RETURN_ERROR_ARG
#undef RETURN_ERROR_ARG
#endif

#ifdef RETURN_ERROR_LS
#undef RETURN_ERROR_LS
#endif

#ifdef RETURN_IF_FAIL
#undef RETURN_IF_FAIL
#endif

0 comments on commit 77f73d5

Please sign in to comment.