diff --git a/common/value.cc b/common/value.cc index a559efee1..07a62e1c4 100644 --- a/common/value.cc +++ b/common/value.cc @@ -270,15 +270,12 @@ bool Value::IsZeroValue() const { } std::ostream& operator<<(std::ostream& out, const Value& value) { - value.AssertIsValid(); return absl::visit( [&out](const auto& alternative) -> std::ostream& { if constexpr (std::is_same_v< absl::remove_cvref_t, absl::monostate>) { - // In optimized builds, we do nothing. In debug builds we cannot - // reach here. - return out; + return out << "default ctor Value"; } else { return out << alternative; } @@ -513,15 +510,12 @@ bool ValueView::IsZeroValue() const { } std::ostream& operator<<(std::ostream& out, ValueView value) { - value.AssertIsValid(); return absl::visit( [&out](auto alternative) -> std::ostream& { if constexpr (std::is_same_v< absl::remove_cvref_t, absl::monostate>) { - // In optimized builds, we do nothing. In debug builds we cannot - // reach here. - return out; + return out << "default ctor ValueView"; } else { return out << alternative; } diff --git a/common/value.h b/common/value.h index 28ea3f9ca..b0aa445fd 100644 --- a/common/value.h +++ b/common/value.h @@ -84,6 +84,9 @@ class Value final { : variant_((other.AssertIsValid(), other.variant_)) {} Value& operator=(const Value& other) { + if (this == std::addressof(other)) { + return *this; + } other.AssertIsValid(); ABSL_DCHECK(this != std::addressof(other)) << "Value should not be copied to itself"; diff --git a/common/value_test.cc b/common/value_test.cc index 80379ecf8..24b879c29 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -44,13 +44,14 @@ TEST(Value, GetTypeName) { EXPECT_DEBUG_DEATH(static_cast(moved_from_value.GetTypeName()), _); } -TEST(Value, DebugStringDebugDeath) { +TEST(Value, DebugStringUinitializedValue) { Value moved_from_value = BoolValue(true); Value value = std::move(moved_from_value); IS_INITIALIZED(moved_from_value); static_cast(value); std::ostringstream out; - EXPECT_DEBUG_DEATH(static_cast(out << moved_from_value), _); + out << moved_from_value; + EXPECT_EQ(out.str(), "default ctor Value"); } TEST(Value, NativeValueIdDebugDeath) { diff --git a/common/value_testing.cc b/common/value_testing.cc index 3e0977172..4306936dc 100644 --- a/common/value_testing.cc +++ b/common/value_testing.cc @@ -28,9 +28,7 @@ namespace cel { -void PrintTo(const Value& value, std::ostream* os) { - *os << value.DebugString() << "\n"; -} +void PrintTo(const Value& value, std::ostream* os) { *os << value << "\n"; } namespace test { namespace { diff --git a/conformance/BUILD b/conformance/BUILD index 09ff3acaf..aa9d7fcaf 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -167,6 +167,7 @@ cc_binary( for args in [ [], ["--opt"], + ["--recursive"], ] ] @@ -236,6 +237,11 @@ cc_binary( "--arena", "--opt", ], + [ + "--modern", + "--arena", + "--recursive", + ], ] ] diff --git a/conformance/server.cc b/conformance/server.cc index 7ca2fd64f..b7d3a74bb 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -78,6 +78,9 @@ ABSL_FLAG( ABSL_FLAG(bool, arena, false, "Use arena memory manager (default: global heap ref-counted). Only " "affects the modern implementation"); +ABSL_FLAG(bool, recursive, false, + "Enable recursive plans. Depth limited to slightly more than the " + "default nesting limit."); namespace google::api::expr::runtime { @@ -163,6 +166,10 @@ class LegacyConformanceServiceImpl : public ConformanceServiceInterface { options.constant_arena = constant_arena; } + if (absl::GetFlag(FLAGS_recursive)) { + options.max_recursion_depth = 48; + } + std::unique_ptr builder = CreateCelExpressionBuilder(options); auto type_registry = builder->GetTypeRegistry(); @@ -284,6 +291,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { options.enable_timestamp_duration_overflow_errors = true; options.enable_heterogeneous_equality = true; options.enable_empty_wrapper_null_unboxing = true; + if (absl::GetFlag(FLAGS_recursive)) { + options.max_recursion_depth = 48; + } return absl::WrapUnique( new ModernConformanceServiceImpl(options, use_arena, optimize)); diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 01ca54078..137b3a5e7 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -28,8 +28,12 @@ cc_library( "//base:ast", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", + "//common:native_type", "//common:value", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", + "//eval/eval:trace_step", + "//internal:casts", "//runtime:runtime_options", "//runtime/internal:issue_collector", "@com_google_absl//absl/algorithm:container", @@ -40,6 +44,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], ) @@ -51,10 +56,14 @@ cc_test( ":flat_expr_builder_extensions", ":resolver", "//base/ast_internal:expr", + "//common:casting", "//common:memory", + "//common:native_type", "//common:value", "//eval/eval:const_value_step", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", + "//eval/eval:function_step", "//internal:status_macros", "//internal:testing", "//runtime:function_registry", @@ -83,13 +92,16 @@ cc_library( "//base/ast_internal:ast_impl", "//base/ast_internal:expr", "//common:memory", + "//common:type", "//common:value", + "//eval/eval:attribute_trail", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", "//eval/eval:container_access_step", "//eval/eval:create_list_step", "//eval/eval:create_map_step", "//eval/eval:create_struct_step", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:function_step", "//eval/eval:ident_step", @@ -99,6 +111,7 @@ cc_library( "//eval/eval:select_step", "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", + "//eval/eval:trace_step", "//eval/public:ast_traverse_native", "//eval/public:ast_visitor_native", "//eval/public:cel_type_registry", @@ -108,6 +121,7 @@ cc_library( "//runtime:runtime_issue", "//runtime:runtime_options", "//runtime:type_registry", + "//runtime/internal:convert_constant", "//runtime/internal:issue_collector", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -119,6 +133,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], @@ -201,6 +216,7 @@ cc_test( "//internal:status_macros", "//internal:testing", "//parser", + "//runtime", "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -220,7 +236,9 @@ cc_library( deps = [ ":flat_expr_builder", "//base:ast", + "//common:native_type", "//eval/eval:cel_expression_flat_impl", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/public:cel_expression", "//extensions/protobuf:ast_converters", @@ -243,15 +261,32 @@ cc_test( ], deps = [ ":cel_expression_builder_flat_impl", + ":constant_folding", + ":regex_precompilation_optimization", + "//eval/eval:cel_expression_flat_impl", "//eval/public:activation", "//eval/public:builtin_func_registrar", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_value", + "//eval/public:portable_cel_function_adapter", + "//eval/public/containers:container_backed_map_impl", + "//eval/public/structs:cel_proto_wrapper", + "//eval/public/structs:protobuf_descriptor_type_provider", "//eval/public/testing:matchers", + "//extensions:bindings_ext", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", "//internal:testing", "//parser", + "//parser:macro", "//runtime:runtime_options", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -274,6 +309,7 @@ cc_library( "//common:memory", "//common:value", "//eval/eval:const_value_step", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//internal:status_macros", "//runtime:activation", @@ -462,14 +498,17 @@ cc_library( "//common:native_type", "//common:value", "//eval/eval:compiler_constant_step", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:regex_match_step", "//internal:casts", "//internal:status_macros", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@com_googlesource_code_re2//:re2", ], ) @@ -485,16 +524,18 @@ cc_test( "//base/ast_internal:ast_impl", "//base/ast_internal:expr", "//common:memory", - "//common:type", "//common:value", - "//eval/eval:cel_expression_flat_impl", "//eval/eval:evaluator_core", + "//eval/public:activation", "//eval/public:builtin_func_registrar", + "//eval/public:cel_expression", "//eval/public:cel_options", + "//eval/public:cel_value", "//internal:testing", "//parser", "//runtime:runtime_issue", "//runtime/internal:issue_collector", + "@com_google_absl//absl/status", "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", diff --git a/eval/compiler/cel_expression_builder_flat_impl.cc b/eval/compiler/cel_expression_builder_flat_impl.cc index 578ac961a..fc2395c2a 100644 --- a/eval/compiler/cel_expression_builder_flat_impl.cc +++ b/eval/compiler/cel_expression_builder_flat_impl.cc @@ -27,7 +27,9 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/ast.h" +#include "common/native_type.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/public/cel_expression.h" #include "extensions/protobuf/ast_converters.h" @@ -94,6 +96,14 @@ CelExpressionBuilderFlatImpl::CreateExpressionImpl( warnings->push_back(issue.ToStatus()); } } + if (flat_expr_builder_.options().max_recursion_depth != 0 && + impl.path().size() > 0 && + // mainline expression is exactly one recursive step. + impl.subexpressions().front().size() == 1 && + impl.path().front()->GetNativeTypeId() == + cel::NativeTypeId::For()) { + return CelExpressionRecursiveImpl::Create(std::move(impl)); + } return std::make_unique(std::move(impl)); } diff --git a/eval/compiler/cel_expression_builder_flat_impl_test.cc b/eval/compiler/cel_expression_builder_flat_impl_test.cc index 8935d6c92..d6810093d 100644 --- a/eval/compiler/cel_expression_builder_flat_impl_test.cc +++ b/eval/compiler/cel_expression_builder_flat_impl_test.cc @@ -18,17 +18,40 @@ // flat_expr_builder_test.cc for additional tests. #include "eval/compiler/cel_expression_builder_flat_impl.h" +#include +#include +#include +#include #include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/algorithm/container.h" #include "absl/status/status.h" +#include "eval/compiler/constant_folding.h" +#include "eval/compiler/regex_precompilation_optimization.h" +#include "eval/eval/cel_expression_flat_impl.h" #include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_value.h" +#include "eval/public/containers/container_backed_map_impl.h" +#include "eval/public/portable_cel_function_adapter.h" +#include "eval/public/structs/cel_proto_wrapper.h" +#include "eval/public/structs/protobuf_descriptor_type_provider.h" #include "eval/public/testing/matchers.h" +#include "extensions/bindings_ext.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" #include "internal/testing.h" +#include "parser/macro.h" #include "parser/parser.h" #include "runtime/runtime_options.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" namespace google::api::expr::runtime { @@ -38,10 +61,16 @@ using ::google::api::expr::v1alpha1::CheckedExpr; using ::google::api::expr::v1alpha1::Expr; using ::google::api::expr::v1alpha1::ParsedExpr; using ::google::api::expr::v1alpha1::SourceInfo; +using ::google::api::expr::parser::Macro; using ::google::api::expr::parser::Parse; +using ::google::api::expr::parser::ParseWithMacros; +using ::google::api::expr::test::v1::proto3::NestedTestAllTypes; +using ::google::api::expr::test::v1::proto3::TestAllTypes; using testing::_; using testing::Contains; using testing::HasSubstr; +using testing::IsNull; +using testing::NotNull; using cel::internal::StatusIs; TEST(CelExpressionBuilderFlatImplTest, Error) { @@ -69,6 +98,233 @@ TEST(CelExpressionBuilderFlatImplTest, ParsedExpr) { EXPECT_THAT(result, test::IsCelInt64(3)); } +struct RecursiveTestCase { + std::string test_name; + std::string expr; + test::CelValueMatcher matcher; +}; + +class RecursivePlanTest : public ::testing::TestWithParam { + protected: + absl::Status SetupBuilder(CelExpressionBuilderFlatImpl& builder) { + builder.GetTypeRegistry()->RegisterTypeProvider( + std::make_unique( + google::protobuf::DescriptorPool::generated_pool(), + google::protobuf::MessageFactory::generated_factory())); + CEL_RETURN_IF_ERROR(RegisterBuiltinFunctions(builder.GetRegistry())); + return builder.GetRegistry()->RegisterLazyFunction(CelFunctionDescriptor( + "LazilyBoundMult", false, + {CelValue::Type::kInt64, CelValue::Type::kInt64})); + } + + absl::Status SetupActivation(Activation& activation, google::protobuf::Arena* arena) { + activation.InsertValue("int_1", CelValue::CreateInt64(1)); + activation.InsertValue("string_abc", CelValue::CreateStringView("abc")); + activation.InsertValue("string_def", CelValue::CreateStringView("def")); + auto* map = google::protobuf::Arena::Create(arena); + CEL_RETURN_IF_ERROR( + map->Add(CelValue::CreateStringView("a"), CelValue::CreateInt64(1))); + CEL_RETURN_IF_ERROR( + map->Add(CelValue::CreateStringView("b"), CelValue::CreateInt64(2))); + activation.InsertValue("map_var", CelValue::CreateMap(map)); + auto* msg = google::protobuf::Arena::Create(arena); + msg->mutable_child()->mutable_payload()->set_single_int64(42); + activation.InsertValue("struct_var", + CelProtoWrapper::CreateMessage(msg, arena)); + + CEL_RETURN_IF_ERROR(activation.InsertFunction( + PortableBinaryFunctionAdapter::Create( + "LazilyBoundMult", false, + [](google::protobuf::Arena*, int64_t lhs, int64_t rhs) -> int64_t { + return lhs * rhs; + }))); + + return absl::OkStatus(); + } +}; + +absl::StatusOr ParseWithBind(absl::string_view cel) { + static const std::vector* kMacros = []() { + auto* result = new std::vector(Macro::AllMacros()); + absl::c_copy(cel::extensions::bindings_macros(), + std::back_inserter(*result)); + return result; + }(); + return ParseWithMacros(cel, *kMacros, ""); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); + cel::RuntimeOptions options; + options.container = "google.api.expr.test.v1.proto3"; + google::protobuf::Arena arena; + // Unbounded. + options.max_recursion_depth = -1; + CelExpressionBuilderFlatImpl builder(options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); + cel::RuntimeOptions options; + options.container = "google.api.expr.test.v1.proto3"; + google::protobuf::Arena arena; + // Unbounded. + options.max_recursion_depth = -1; + options.enable_comprehension_list_append = true; + CelExpressionBuilderFlatImpl builder(options); + + ASSERT_OK(SetupBuilder(builder)); + + builder.flat_expr_builder().AddProgramOptimizer( + cel::runtime_internal::CreateConstantFoldingOptimizer( + cel::extensions::ProtoMemoryManagerRef(&arena))); + builder.flat_expr_builder().AddProgramOptimizer( + CreateRegexPrecompilationExtension(options.regex_max_program_size)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) { + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); + cel::RuntimeOptions options; + options.container = "google.api.expr.test.v1.proto3"; + google::protobuf::Arena arena; + auto cb = [](int64_t id, const CelValue& value, google::protobuf::Arena* arena) { + return absl::OkStatus(); + }; + // Unbounded. + options.max_recursion_depth = -1; + options.enable_recursive_tracing = true; + CelExpressionBuilderFlatImpl builder(options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + NotNull()); + + Activation activation; + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Trace(activation, &arena, cb)); + EXPECT_THAT(result, test_case.matcher); +} + +TEST_P(RecursivePlanTest, Disabled) { + google::protobuf::LinkMessageReflection(); + + const RecursiveTestCase& test_case = GetParam(); + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr)); + cel::RuntimeOptions options; + options.container = "google.api.expr.test.v1.proto3"; + google::protobuf::Arena arena; + // disabled. + options.max_recursion_depth = 0; + CelExpressionBuilderFlatImpl builder(options); + + ASSERT_OK(SetupBuilder(builder)); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, + builder.CreateExpression(&parsed_expr.expr(), + &parsed_expr.source_info())); + + EXPECT_THAT(dynamic_cast(plan.get()), + IsNull()); + + Activation activation; + ASSERT_OK(SetupActivation(activation, &arena)); + + ASSERT_OK_AND_ASSIGN(CelValue result, plan->Evaluate(activation, &arena)); + EXPECT_THAT(result, test_case.matcher); +} + +INSTANTIATE_TEST_SUITE_P( + RecursivePlanTest, RecursivePlanTest, + testing::ValuesIn(std::vector{ + {"constant", "'abc'", test::IsCelString("abc")}, + {"call", "1 + 2", test::IsCelInt64(3)}, + {"nested_call", "1 + 1 + 1 + 1", test::IsCelInt64(4)}, + {"and", "true && false", test::IsCelBool(false)}, + {"or", "true || false", test::IsCelBool(true)}, + {"ternary", "(true || false) ? 2 + 2 : 3 + 3", test::IsCelInt64(4)}, + {"create_list", "3 in [1, 2, 3]", test::IsCelBool(true)}, + {"create_list_complex", "3 in [2 / 2, 4 / 2, 6 / 2]", + test::IsCelBool(true)}, + {"ident", "int_1 == 1", test::IsCelBool(true)}, + {"ident_complex", "int_1 + 2 > 4 ? string_abc : string_def", + test::IsCelString("def")}, + {"select", "struct_var.child.payload.single_int64", + test::IsCelInt64(42)}, + {"nested_select", "[map_var.a, map_var.b].size() == 2", + test::IsCelBool(true)}, + {"map_index", "map_var['b']", test::IsCelInt64(2)}, + {"list_index", "[1, 2, 3][1]", test::IsCelInt64(2)}, + {"compre_exists", "[1, 2, 3, 4].exists(x, x == 3)", + test::IsCelBool(true)}, + {"compre_map", "8 in [1, 2, 3, 4].map(x, x * 2)", + test::IsCelBool(true)}, + {"map_var_compre_exists", "map_var.exists(key, key == 'b')", + test::IsCelBool(true)}, + {"map_compre_exists", "{'a': 1, 'b': 2}.exists(k, k == 'b')", + test::IsCelBool(true)}, + {"create_map", "{'a': 42, 'b': 0, 'c': 0}.size()", test::IsCelInt64(3)}, + {"create_struct", + "NestedTestAllTypes{payload: TestAllTypes{single_int64: " + "-42}}.payload.single_int64", + test::IsCelInt64(-42)}, + {"bind", R"(cel.bind(x, "1", x + x + x + x))", + test::IsCelString("1111")}, + {"nested_bind", R"(cel.bind(x, 20, cel.bind(y, 30, x + y)))", + test::IsCelInt64(50)}, + {"bind_with_comprehensions", + R"(cel.bind(x, [1, 2], cel.bind(y, x.map(z, z * 2), y.exists(z, z == 4))))", + test::IsCelBool(true)}, + {"shadowable_value", R"(list == type([]))", test::IsCelBool(true)}, + {"lazily_resolved_function", "LazilyBoundMult(123, 2) == 246", + test::IsCelBool(true)}, + {"re_matches", "matches(string_abc, '[ad][be][cf]')", + test::IsCelBool(true)}, + {"re_matches_receiver", + "(string_abc + string_def).matches(r'(123)?' + r'abc' + r'def')", + test::IsCelBool(true)}, + }), + + [](const testing::TestParamInfo& info) -> std::string { + return info.param.test_name; + }); + TEST(CelExpressionBuilderFlatImplTest, ParsedExprWithWarnings) { ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, Parse("1 + 2")); cel::RuntimeOptions options; diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index 8b4ef8951..948f6c4e7 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -33,6 +33,7 @@ #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "internal/status_macros.h" #include "runtime/activation.h" @@ -55,7 +56,8 @@ using ::cel::builtin::kAnd; using ::cel::builtin::kOr; using ::cel::builtin::kTernary; using ::cel::runtime_internal::ConvertConstant; - +using ::google::api::expr::runtime::CreateConstValueDirectStep; +using ::google::api::expr::runtime::CreateConstValueStep; using ::google::api::expr::runtime::EvaluationListener; using ::google::api::expr::runtime::ExecutionFrame; using ::google::api::expr::runtime::ExecutionPath; @@ -197,7 +199,7 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, // the current capacity. state_.value_stack().SetMaxSize(subplan.size()); - auto result = frame.Evaluate(EvaluationListener()); + auto result = frame.Evaluate(); // If this would be a runtime error, then don't adjust the program plan, but // rather allow the error to occur at runtime to preserve the evaluation // contract with non-constant folding use cases. @@ -211,9 +213,13 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context, } ExecutionPath new_plan; - CEL_ASSIGN_OR_RETURN(new_plan.emplace_back(), - google::api::expr::runtime::CreateConstValueStep( - std::move(value), node.id(), false)); + if (context.options().max_recursion_depth != 0) { + return context.ReplaceSubplan( + node, CreateConstValueDirectStep(std::move(value), node.id()), 1); + } + CEL_ASSIGN_OR_RETURN( + new_plan.emplace_back(), + CreateConstValueStep(std::move(value), node.id(), false)); return context.ReplaceSubplan(node, std::move(new_plan)); } diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index d4bf52899..007345e3a 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -371,8 +371,9 @@ TEST_F(UpdatedConstantFoldingTest, CreatesMap) { program_builder.ExitSubexpression(&value); // create map - ASSERT_OK_AND_ASSIGN( - step, CreateCreateStructStepForMap(create_map.struct_expr(), 3)); + ASSERT_OK_AND_ASSIGN(step, + CreateCreateStructStepForMap( + create_map.struct_expr().entries().size(), {}, 3)); program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); @@ -428,8 +429,9 @@ TEST_F(UpdatedConstantFoldingTest, CreatesInvalidMap) { program_builder.ExitSubexpression(&value); // create map - ASSERT_OK_AND_ASSIGN( - step, CreateCreateStructStepForMap(create_map.struct_expr(), 3)); + ASSERT_OK_AND_ASSIGN(step, + CreateCreateStructStepForMap( + create_map.struct_expr().entries().size(), {}, 3)); program_builder.AddStep(std::move(step)); program_builder.ExitSubexpression(&create_map); diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 9f217ef2f..e323c72db 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -16,6 +16,7 @@ #include "eval/compiler/flat_expr_builder.h" +#include #include #include #include @@ -39,6 +40,7 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "base/ast.h" @@ -46,16 +48,20 @@ #include "base/ast_internal/expr.h" #include "base/builtins.h" #include "common/memory.h" +#include "common/type.h" +#include "common/value.h" #include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/compiler/resolver.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" #include "eval/eval/container_access_step.h" #include "eval/eval/create_list_step.h" #include "eval/eval/create_map_step.h" #include "eval/eval/create_struct_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/function_step.h" #include "eval/eval/ident_step.h" @@ -65,10 +71,12 @@ #include "eval/eval/select_step.h" #include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" +#include "eval/eval/trace_step.h" #include "eval/public/ast_traverse_native.h" #include "eval/public/ast_visitor_native.h" #include "eval/public/source_position_native.h" #include "internal/status_macros.h" +#include "runtime/internal/convert_constant.h" #include "runtime/internal/issue_collector.h" #include "runtime/runtime_issue.h" #include "runtime/runtime_options.h" @@ -80,9 +88,12 @@ namespace { using ::cel::Ast; using ::cel::RuntimeIssue; +using ::cel::StringValue; +using ::cel::Value; using ::cel::ValueManager; using ::cel::ast_internal::AstImpl; using ::cel::ast_internal::AstTraverse; +using ::cel::runtime_internal::ConvertConstant; using ::cel::runtime_internal::IssueCollector; // Forward declare to resolve circular dependency for short_circuiting visitors. @@ -337,6 +348,26 @@ class ComprehensionVisitor { size_t accu_slot_; }; +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::ast_internal::CreateList& create_list_expr) { + absl::flat_hash_set optional_indices; + for (const auto& optional_index : create_list_expr.optional_indices()) { + optional_indices.insert(optional_index); + } + return optional_indices; +} + +absl::flat_hash_set MakeOptionalIndicesSet( + const cel::ast_internal::CreateStruct& create_struct_expr) { + absl::flat_hash_set optional_indices; + for (size_t i = 0; i < create_struct_expr.entries().size(); ++i) { + if (create_struct_expr.entries()[i].optional_entry()) { + optional_indices.insert(static_cast(i)); + } + } + return optional_indices; +} + class FlatExprVisitor : public cel::ast_internal::AstVisitor { public: FlatExprVisitor( @@ -403,6 +434,14 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { } } + auto* subexpression = program_builder_.current(); + if (subexpression != nullptr && options_.enable_recursive_tracing && + subexpression->IsRecursive()) { + auto program = subexpression->ExtractRecursiveProgram(); + subexpression->set_recursive_program( + std::make_unique(std::move(program.step)), program.depth); + } + program_builder_.ExitSubexpression(expr); if (!comprehension_stack_.empty() && @@ -420,7 +459,23 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { return; } - AddStep(CreateConstValueStep(*const_expr, expr->id(), value_factory_)); + absl::StatusOr converted_value = + ConvertConstant(*const_expr, value_factory_); + + if (!converted_value.ok()) { + SetProgressStatusError(converted_value.status()); + return; + } + + if (options_.max_recursion_depth > 0 || options_.max_recursion_depth < 0) { + SetRecursiveStep(CreateConstValueDirectStep( + std::move(converted_value).value(), expr->id()), + 1); + return; + } + + AddStep( + CreateConstValueStep(std::move(converted_value).value(), expr->id())); } struct SlotLookupResult { @@ -477,6 +532,9 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { // Attempt to resolve a select expression as a namespaced identifier for an // enum or type constant value. + absl::optional const_value; + int64_t select_root_id = -1; + while (!namespace_stack_.empty()) { const auto& select_node = namespace_stack_.front(); // Generate path in format ".....". @@ -487,23 +545,32 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { // qualified path present in the expression. Whether the identifier // can be resolved to a type instance depends on whether the option to // 'enable_qualified_type_identifiers' is set to true. - auto const_value = - resolver_.FindConstant(qualified_path, select_expr->id()); + const_value = resolver_.FindConstant(qualified_path, select_expr->id()); if (const_value) { - AddStep(CreateShadowableValueStep( - qualified_path, std::move(*const_value), select_expr->id())); resolved_select_expr_ = select_expr; + select_root_id = select_expr->id(); namespace_stack_.clear(); - return; + break; } namespace_stack_.pop_front(); } - // Attempt to resolve a simple identifier as an enum or type constant value. - auto const_value = resolver_.FindConstant(path, expr->id()); + if (!const_value) { + // Attempt to resolve a simple identifier as an enum or type constant + // value. + const_value = resolver_.FindConstant(path, expr->id()); + select_root_id = expr->id(); + } + if (const_value) { - AddStep( - CreateShadowableValueStep(path, std::move(*const_value), expr->id())); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectShadowableValueStep( + path, std::move(*const_value), select_root_id), + 1); + return; + } + AddStep(CreateShadowableValueStep(path, std::move(*const_value), + select_root_id)); return; } @@ -511,16 +578,40 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { SlotLookupResult slot = LookupSlot(path); if (slot.subexpression >= 0) { - AddStep( - CreateCheckLazyInitStep(slot.slot, slot.subexpression, expr->id())); - AddStep(CreateAssignSlotStep(slot.slot)); + auto* subexpression = + program_builder_.GetExtractedSubexpression(slot.subexpression); + if (subexpression == nullptr) { + SetProgressStatusError( + absl::InternalError("bad subexpression reference")); + } + if (subexpression->IsRecursive()) { + const auto& program = subexpression->recursive_program(); + SetRecursiveStep( + CreateDirectLazyInitStep(slot.slot, program.step.get(), expr->id()), + program.depth + 1); + } else { + // Off by one since mainline expression will be index 0. + AddStep(CreateCheckLazyInitStep(slot.slot, slot.subexpression + 1, + expr->id())); + AddStep(CreateAssignSlotStep(slot.slot)); + } return; } else if (slot.slot >= 0) { - AddStep(CreateIdentStepForSlot(*ident_expr, slot.slot, expr->id())); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectSlotIdentStep(ident_expr->name(), + slot.slot, expr->id()), + 1); + } else { + AddStep(CreateIdentStepForSlot(*ident_expr, slot.slot, expr->id())); + } return; } - - AddStep(CreateIdentStep(*ident_expr, expr->id())); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectIdentStep(ident_expr->name(), expr->id()), + 1); + } else { + AddStep(CreateIdentStep(*ident_expr, expr->id())); + } } void PreVisitSelect(const cel::ast_internal::Select* select_expr, @@ -583,6 +674,26 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { return; } + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != 1) { + SetProgressStatusError(absl::InternalError( + "unexpected number of dependencies for select operation.")); + return; + } + StringValue field = + value_factory_.CreateUncheckedStringValue(select_expr->field()); + + SetRecursiveStep( + CreateDirectSelectStep(std::move(deps[0]), std::move(field), + select_expr->test_only(), expr->id(), + options_.enable_empty_wrapper_null_unboxing, + enable_optional_types_), + *depth + 1); + return; + } + AddStep(CreateSelectStep(*select_expr, expr->id(), options_.enable_empty_wrapper_null_unboxing, value_factory_, enable_optional_types_)); @@ -630,6 +741,217 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { } } + absl::optional RecursionEligible() { + if (program_builder_.current() == nullptr) { + return absl::nullopt; + } + absl::optional depth = + program_builder_.current()->RecursiveDependencyDepth(); + if (!depth.has_value()) { + // one or more of the dependencies isn't eligible. + return depth; + } + if (options_.max_recursion_depth < 0 || + *depth < options_.max_recursion_depth) { + return depth; + } + return absl::nullopt; + } + + std::vector> + ExtractRecursiveDependencies() { + // Must check recursion eligibility before calling. + ABSL_DCHECK(program_builder_.current() != nullptr); + + return program_builder_.current()->ExtractRecursiveDependencies(); + } + + void MaybeMakeTernaryRecursive(const cel::ast_internal::Expr* expr) { + if (options_.max_recursion_depth == 0) { + return; + } + if (expr->call_expr().args().size() != 3) { + SetProgressStatusError(absl::InvalidArgumentError( + "unexpected number of args for builtin ternary")); + } + + const cel::ast_internal::Expr* condition_expr = + &expr->call_expr().args()[0]; + const cel::ast_internal::Expr* left_expr = &expr->call_expr().args()[1]; + const cel::ast_internal::Expr* right_expr = &expr->call_expr().args()[2]; + + auto* condition_plan = program_builder_.GetSubexpression(condition_expr); + auto* left_plan = program_builder_.GetSubexpression(left_expr); + auto* right_plan = program_builder_.GetSubexpression(right_expr); + + int max_depth = 0; + if (condition_plan == nullptr || !condition_plan->IsRecursive()) { + return; + } + max_depth = std::max(max_depth, condition_plan->recursive_program().depth); + + if (left_plan == nullptr || !left_plan->IsRecursive()) { + return; + } + max_depth = std::max(max_depth, left_plan->recursive_program().depth); + + if (right_plan == nullptr || !right_plan->IsRecursive()) { + return; + } + max_depth = std::max(max_depth, right_plan->recursive_program().depth); + + if (options_.max_recursion_depth >= 0 && + max_depth >= options_.max_recursion_depth) { + return; + } + + SetRecursiveStep( + CreateDirectTernaryStep(condition_plan->ExtractRecursiveProgram().step, + left_plan->ExtractRecursiveProgram().step, + right_plan->ExtractRecursiveProgram().step, + expr->id(), options_.short_circuiting), + max_depth + 1); + } + + void MaybeMakeShortcircuitRecursive(const cel::ast_internal::Expr* expr, + bool is_or) { + if (options_.max_recursion_depth == 0) { + return; + } + if (expr->call_expr().args().size() != 2) { + SetProgressStatusError(absl::InvalidArgumentError( + "unexpected number of args for builtin ternary")); + } + const cel::ast_internal::Expr* left_expr = &expr->call_expr().args()[0]; + const cel::ast_internal::Expr* right_expr = &expr->call_expr().args()[1]; + + auto* left_plan = program_builder_.GetSubexpression(left_expr); + auto* right_plan = program_builder_.GetSubexpression(right_expr); + + int max_depth = 0; + if (left_plan == nullptr || !left_plan->IsRecursive()) { + return; + } + max_depth = std::max(max_depth, left_plan->recursive_program().depth); + + if (right_plan == nullptr || !right_plan->IsRecursive()) { + return; + } + max_depth = std::max(max_depth, right_plan->recursive_program().depth); + + if (options_.max_recursion_depth >= 0 && + max_depth >= options_.max_recursion_depth) { + return; + } + + if (is_or) { + SetRecursiveStep( + CreateDirectOrStep(left_plan->ExtractRecursiveProgram().step, + right_plan->ExtractRecursiveProgram().step, + expr->id(), options_.short_circuiting), + max_depth + 1); + } else { + SetRecursiveStep( + CreateDirectAndStep(left_plan->ExtractRecursiveProgram().step, + right_plan->ExtractRecursiveProgram().step, + expr->id(), options_.short_circuiting), + max_depth + 1); + } + } + + void MaybeMakeBindRecursive( + const cel::ast_internal::Expr* expr, + const cel::ast_internal::Comprehension* comprehension, size_t accu_slot) { + if (options_.max_recursion_depth == 0) { + return; + } + + auto* result_plan = + program_builder_.GetSubexpression(&comprehension->result()); + + if (result_plan == nullptr || !result_plan->IsRecursive()) { + return; + } + + int result_depth = result_plan->recursive_program().depth; + + if (options_.max_recursion_depth > 0 && + result_depth >= options_.max_recursion_depth) { + return; + } + + auto program = result_plan->ExtractRecursiveProgram(); + SetRecursiveStep( + CreateDirectBindStep(accu_slot, std::move(program.step), expr->id()), + result_depth + 1); + } + + void MaybeMakeComprehensionRecursive( + const cel::ast_internal::Expr* expr, + const cel::ast_internal::Comprehension* comprehension, size_t iter_slot, + size_t accu_slot) { + if (options_.max_recursion_depth == 0) { + return; + } + + auto* accu_plan = + program_builder_.GetSubexpression(&comprehension->accu_init()); + + if (accu_plan == nullptr || !accu_plan->IsRecursive()) { + return; + } + + auto* range_plan = + program_builder_.GetSubexpression(&comprehension->iter_range()); + + if (range_plan == nullptr || !range_plan->IsRecursive()) { + return; + } + + auto* loop_plan = + program_builder_.GetSubexpression(&comprehension->loop_step()); + + if (loop_plan == nullptr || !loop_plan->IsRecursive()) { + return; + } + + auto* condition_plan = + program_builder_.GetSubexpression(&comprehension->loop_condition()); + + if (condition_plan == nullptr || !condition_plan->IsRecursive()) { + return; + } + + auto* result_plan = + program_builder_.GetSubexpression(&comprehension->result()); + + if (result_plan == nullptr || !result_plan->IsRecursive()) { + return; + } + + int max_depth = 0; + max_depth = std::max(max_depth, accu_plan->recursive_program().depth); + max_depth = std::max(max_depth, range_plan->recursive_program().depth); + max_depth = std::max(max_depth, loop_plan->recursive_program().depth); + max_depth = std::max(max_depth, condition_plan->recursive_program().depth); + max_depth = std::max(max_depth, result_plan->recursive_program().depth); + + if (options_.max_recursion_depth > 0 && + max_depth >= options_.max_recursion_depth) { + return; + } + + auto step = CreateDirectComprehensionStep( + iter_slot, accu_slot, range_plan->ExtractRecursiveProgram().step, + accu_plan->ExtractRecursiveProgram().step, + loop_plan->ExtractRecursiveProgram().step, + condition_plan->ExtractRecursiveProgram().step, + result_plan->ExtractRecursiveProgram().step, options_.short_circuiting, + expr->id()); + + SetRecursiveStep(std::move(step), max_depth + 1); + } + // Invoked after all child nodes are processed. void PostVisitCall(const cel::ast_internal::Call* call_expr, const cel::ast_internal::Expr* expr, @@ -642,11 +964,31 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { if (cond_visitor) { cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); + if (call_expr->function() == cel::builtin::kTernary) { + MaybeMakeTernaryRecursive(expr); + } else if (call_expr->function() == cel::builtin::kOr) { + MaybeMakeShortcircuitRecursive(expr, /* is_or= */ true); + } else if (call_expr->function() == cel::builtin::kAnd) { + MaybeMakeShortcircuitRecursive(expr, /* is_or= */ false); + } return; } // Special case for "_[_]". if (call_expr->function() == cel::builtin::kIndex) { + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto args = ExtractRecursiveDependencies(); + if (args.size() != 2) { + SetProgressStatusError(absl::InvalidArgumentError( + "unexpected number of args for builtin index operator")); + } + SetRecursiveStep(CreateDirectContainerAccessStep( + std::move(args[0]), std::move(args[1]), + enable_optional_types_, expr->id()), + *depth + 1); + return; + } AddStep(CreateContainerAccessStep(*call_expr, expr->id(), enable_optional_types_)); return; @@ -679,7 +1021,7 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { } } - AddStep(CreateResolvedFunctionStep(call_expr, expr, function)); + AddResolvedFunctionStep(call_expr, expr, function); } void PreVisitComprehension( @@ -851,7 +1193,6 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { } } - // Nothing to do. void PostVisitTarget(const cel::ast_internal::Expr* expr, const cel::ast_internal::SourcePosition*) override { if (!progress_status_.ok()) { @@ -876,10 +1217,27 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { comprehension_stack_.back(); if (comprehension.is_optimizable_list_append && &(comprehension.comprehension->accu_init()) == expr) { - AddStep(CreateCreateMutableListStep(*list_expr, expr->id())); + if (options_.max_recursion_depth != 0) { + SetRecursiveStep(CreateDirectMutableListStep(expr->id()), 1); + return; + } + AddStep(CreateMutableListStep(expr->id())); return; } } + absl::optional depth = RecursionEligible(); + if (depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != list_expr->elements().size()) { + SetProgressStatusError(absl::InternalError( + "Unexpected number of plan elements for CreateList expr")); + return; + } + auto step = CreateDirectListStep( + std::move(deps), MakeOptionalIndicesSet(*list_expr), expr->id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; + } AddStep(CreateCreateListStep(*list_expr, expr->id())); } @@ -900,25 +1258,57 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { ValidateOrError(entry.has_map_key(), "Map entry missing key"); ValidateOrError(entry.has_value(), "Map entry missing value"); } - AddStep(CreateCreateStructStepForMap(*struct_expr, expr->id())); + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != 2 * struct_expr->entries().size()) { + SetProgressStatusError(absl::InternalError( + "Unexpected number of plan elements for CreateStruct expr")); + return; + } + auto step = CreateDirectCreateMapStep( + std::move(deps), MakeOptionalIndicesSet(*struct_expr), expr->id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; + } + AddStep(CreateCreateStructStepForMap(struct_expr->entries().size(), + MakeOptionalIndicesSet(*struct_expr), + expr->id())); return; } // If the message name is not empty, then the message name must be resolved // within the container, and if a descriptor is found, then a proto message // creation step will be created. - auto status_or_maybe_type = resolver_.FindType(message_name, expr->id()); - if (!status_or_maybe_type.ok()) { - SetProgressStatusError(status_or_maybe_type.status()); + auto status_or_resolved_fields = + ResolveCreateStructFields(*struct_expr, expr->id()); + if (!status_or_resolved_fields.ok()) { + SetProgressStatusError(status_or_resolved_fields.status()); return; } - if (ValidateOrError(status_or_maybe_type->has_value(), - "Invalid struct creation: missing type info for '", - message_name, "'")) { - AddStep(CreateCreateStructStepForStruct( - *struct_expr, std::move((*status_or_maybe_type)->first), expr->id(), - value_factory())); + std::string resolved_name = + std::move(status_or_resolved_fields.value().first); + std::vector fields = + std::move(status_or_resolved_fields.value().second); + + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto deps = ExtractRecursiveDependencies(); + if (deps.size() != struct_expr->entries().size()) { + SetProgressStatusError(absl::InternalError( + "Unexpected number of plan elements for CreateStruct expr")); + return; + } + auto step = CreateDirectCreateStructStep( + std::move(resolved_name), std::move(fields), std::move(deps), + MakeOptionalIndicesSet(*struct_expr), expr->id()); + SetRecursiveStep(std::move(step), *depth + 1); + return; } + + AddStep(CreateCreateStructStep(std::move(resolved_name), std::move(fields), + MakeOptionalIndicesSet(*struct_expr), + expr->id())); } absl::Status progress_status() const { return progress_status_; } @@ -933,10 +1323,10 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { suppressed_branches_.insert(expr); } - absl::StatusOr> CreateResolvedFunctionStep( - const cel::ast_internal::Call* call_expr, - const cel::ast_internal::Expr* expr, absl::string_view function, - bool collect_issues = true) { + void AddResolvedFunctionStep(const cel::ast_internal::Call* call_expr, + const cel::ast_internal::Expr* expr, + absl::string_view function, + bool collect_issues = true) { // Establish the search criteria for a given function. bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args().size() + (receiver_style ? 1 : 0); @@ -947,8 +1337,18 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { auto lazy_overloads = resolver_.FindLazyOverloads( function, call_expr->has_target(), arguments_matcher, expr->id()); if (!lazy_overloads.empty()) { - return CreateFunctionStep(*call_expr, expr->id(), - std::move(lazy_overloads)); + auto depth = RecursionEligible(); + if (depth.has_value()) { + auto args = program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep(CreateDirectLazyFunctionStep( + expr->id(), *call_expr, std::move(args), + std::move(lazy_overloads)), + *depth + 1); + return; + } + AddStep(CreateFunctionStep(*call_expr, expr->id(), + std::move(lazy_overloads))); + return; } // Second, search for eagerly defined function overloads. @@ -965,16 +1365,29 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { status = issue_collector_.AddIssue(RuntimeIssue::CreateWarning( std::move(status), RuntimeIssue::ErrorCode::kNoMatchingOverload)); } - CEL_RETURN_IF_ERROR(status); + if (!status.ok()) { + SetProgressStatusError(status); + return; + } + } + auto recursion_depth = RecursionEligible(); + if (recursion_depth.has_value()) { + // Nonnull while active -- nullptr indicates logic error elsewhere in the + // builder. + ABSL_DCHECK(program_builder_.current() != nullptr); + auto args = program_builder_.current()->ExtractRecursiveDependencies(); + SetRecursiveStep( + CreateDirectFunctionStep(expr->id(), *call_expr, std::move(args), + std::move(overloads)), + *recursion_depth + 1); + return; } - return CreateFunctionStep(*call_expr, expr->id(), std::move(overloads)); + AddStep(CreateFunctionStep(*call_expr, expr->id(), std::move(overloads))); } - absl::StatusOr> CreateChainedOptionalStep( - const cel::ast_internal::Call* call_expr, - const cel::ast_internal::Expr* expr) { - return CreateResolvedFunctionStep(call_expr, expr, call_expr->function(), - false); + void AddChainedOptionalStep(const cel::ast_internal::Call* call_expr, + const cel::ast_internal::Expr* expr) { + AddResolvedFunctionStep(call_expr, expr, call_expr->function(), false); } void AddStep(absl::StatusOr> step) { @@ -991,6 +1404,18 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { } } + void SetRecursiveStep(std::unique_ptr step, int depth) { + if (!progress_status_.ok() || PlanningSuppressed()) { + return; + } + if (program_builder_.current() == nullptr) { + SetProgressStatusError(absl::InternalError( + "CEL AST traversal out of order in flat_expr_builder.")); + return; + } + program_builder_.current()->set_recursive_program(std::move(step), depth); + } + void SetProgressStatusError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; @@ -1070,14 +1495,48 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor { return absl::InternalError("Failed to extract subexpression"); } - // off by one since mainline expression is handled separately. - record.subexpression = index + 1; + record.subexpression = index; record.visitor->MarkAccuInitExtracted(); return absl::OkStatus(); } + // Resolve the name of the message type being created and the names of set + // fields. + absl::StatusOr>> + ResolveCreateStructFields( + const cel::ast_internal::CreateStruct& create_struct_expr, + int64_t expr_id) { + absl::string_view ast_name = create_struct_expr.message_name(); + + absl::optional> type; + CEL_ASSIGN_OR_RETURN(type, resolver_.FindType(ast_name, expr_id)); + + if (!type.has_value()) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid struct creation: missing type info for '", ast_name, "'")); + } + + std::string resolved_name = std::move(type).value().first; + + std::vector fields; + fields.reserve(create_struct_expr.entries().size()); + for (const auto& entry : create_struct_expr.entries()) { + CEL_ASSIGN_OR_RETURN(auto field, + value_factory().FindStructTypeFieldByName( + resolved_name, entry.field_key())); + if (!field.has_value()) { + return absl::InvalidArgumentError( + absl::StrCat("Invalid message creation: field '", entry.field_key(), + "' not found in '", resolved_name, "'")); + } + fields.push_back(entry.field_key()); + } + + return std::make_pair(std::move(resolved_name), std::move(fields)); + } + const Resolver& resolver_; ValueManager& value_factory_; absl::Status progress_status_; @@ -1187,24 +1646,22 @@ void BinaryCondVisitor::PostVisitTarget(const cel::ast_internal::Expr* expr) { } void BinaryCondVisitor::PostVisit(const cel::ast_internal::Expr* expr) { - absl::StatusOr> step; switch (cond_) { case BinaryCond::kAnd: - step = CreateAndStep(expr->id()); + visitor_->AddStep(CreateAndStep(expr->id())); break; case BinaryCond::kOr: - step = CreateOrStep(expr->id()); + visitor_->AddStep(CreateOrStep(expr->id())); break; case BinaryCond::kOptionalOr: - step = visitor_->CreateChainedOptionalStep(&expr->call_expr(), expr); + visitor_->AddChainedOptionalStep(&expr->call_expr(), expr); break; case BinaryCond::kOptionalOrValue: - step = visitor_->CreateChainedOptionalStep(&expr->call_expr(), expr); + visitor_->AddChainedOptionalStep(&expr->call_expr(), expr); break; default: ABSL_UNREACHABLE(); } - visitor_->AddStep(std::move(step)); if (short_circuiting_) { // If short-circuiting is enabled, point the conditional jump past the // boolean operator step. @@ -1398,7 +1855,15 @@ void ComprehensionVisitor::PostVisitArgTrivial( } } -void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) {} +void ComprehensionVisitor::PostVisit(const cel::ast_internal::Expr* expr) { + if (is_trivial_) { + visitor_->MaybeMakeBindRecursive(expr, &expr->comprehension_expr(), + accu_slot_); + return; + } + visitor_->MaybeMakeComprehensionRecursive(expr, &expr->comprehension_expr(), + iter_slot_, accu_slot_); +} // Flattens the expression table into the end of the mainline expression vector // and returns an index to the individual sub expressions. diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index d545e0e81..4ab422b25 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -42,6 +42,7 @@ class FlatExprBuilder { const CelTypeRegistry& type_registry, const cel::RuntimeOptions& options) : options_(options), + container_(options.container), function_registry_(function_registry), type_registry_(type_registry.InternalGetModernRegistry()) {} @@ -49,6 +50,7 @@ class FlatExprBuilder { const cel::TypeRegistry& type_registry, const cel::RuntimeOptions& options) : options_(options), + container_(options.container), function_registry_(function_registry), type_registry_(type_registry) {} diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc index 2a4909359..5db07b1ab 100644 --- a/eval/compiler/flat_expr_builder_comprehensions_test.cc +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -42,6 +42,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "parser/parser.h" +#include "runtime/runtime.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { @@ -53,9 +54,25 @@ using ::google::api::expr::v1alpha1::ParsedExpr; using testing::HasSubstr; using cel::internal::StatusIs; -TEST(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { - cel::RuntimeOptions options; - options.enable_comprehension_list_append = true; +class CelExpressionBuilderFlatImplComprehensionsTest + : public testing::TestWithParam { + public: + CelExpressionBuilderFlatImplComprehensionsTest() = default; + + bool enable_recursive_planning() { return GetParam(); } + + cel::RuntimeOptions GetRuntimeOptions() { + cel::RuntimeOptions options; + if (enable_recursive_planning()) { + options.max_recursion_depth = -1; + } + options.enable_comprehension_list_append = true; + return options; + } +}; + +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, @@ -72,9 +89,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, NestedComp) { EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { - cel::RuntimeOptions options; - options.enable_comprehension_list_append = true; +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, parser::Parse("[1, 2].map(x, x * 2)")); @@ -94,9 +110,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, MapComp) { test::EqualsCelValue(CelValue::CreateInt64(4))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { - cel::RuntimeOptions options; - options.enable_comprehension_list_append = true; +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, @@ -112,9 +127,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneTrue) { EXPECT_THAT(result, test::IsCelBool(true)); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { - cel::RuntimeOptions options; - options.enable_comprehension_list_append = true; +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK_AND_ASSIGN(auto parsed_expr, @@ -130,8 +144,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, ExistsOneFalse) { EXPECT_THAT(result, test::IsCelBool(false)); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { - cel::RuntimeOptions options; +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { + cel::RuntimeOptions options = GetRuntimeOptions(); options.unknown_processing = UnknownProcessingOptions::kAttributeAndFunction; CelExpressionBuilderFlatImpl builder(options); @@ -167,8 +181,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, ListCompWithUnknowns) { testing::Eq(1)); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - InvalidComprehensionWithRewrite) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + InvalidComprehensionWithRewrite) { CheckedExpr expr; // The rewrite step which occurs when an identifier gets a more qualified name // from the reference map has the potential to make invalid comprehensions @@ -195,7 +209,7 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, } })pb", &expr); - cel::RuntimeOptions options; + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); EXPECT_THAT(builder.CreateExpression(&expr).status(), @@ -204,8 +218,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("Invalid empty expression")))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithConcatVulernability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithConcatVulernability) { CheckedExpr expr; // The comprehension loop step performs an unsafe concatenation of the // accumulation variable with itself or one of its children. @@ -248,7 +262,7 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, })pb", &expr); - cel::RuntimeOptions options; + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); @@ -258,8 +272,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithListVulernability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithListVulernability) { CheckedExpr expr; // The comprehension google::protobuf::TextFormat::ParseFromString( @@ -292,7 +306,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, )pb", &expr); - CelExpressionBuilderFlatImpl builder; + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -301,8 +316,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithStructVulernability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithStructVulernability) { CheckedExpr expr; // The comprehension loop step builds a deeply nested struct which expands // exponentially. @@ -348,7 +363,7 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, )pb", &expr); - cel::RuntimeOptions options; + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); @@ -358,8 +373,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithNestedComprehensionResultVulernability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionResultVulernability) { CheckedExpr expr; // The nested comprehension performs an unsafe concatenation on the parent // accumulator variable within its 'result' expression. @@ -416,7 +431,7 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, )pb", &expr); - cel::RuntimeOptions options; + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); @@ -426,8 +441,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithNestedComprehensionLoopStepVulernability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepVulernability) { CheckedExpr expr; // The nested comprehension performs an unsafe concatenation on the parent // accumulator variable within its 'loop_step'. @@ -463,7 +478,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, )pb", &expr); - CelExpressionBuilderFlatImpl builder; + cel::RuntimeOptions options = GetRuntimeOptions(); + CelExpressionBuilderFlatImpl builder(options); builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -472,8 +488,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithNestedComprehensionLoopStepVulernabilityResult) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepVulernabilityResult) { CheckedExpr expr; // The nested comprehension performs an unsafe concatenation on the parent // accumulator. @@ -513,6 +529,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, } )pb", &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder; builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); @@ -522,8 +540,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, - ComprehensionWithNestedComprehensionLoopStepIterRangeVulnerability) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + ComprehensionWithNestedComprehensionLoopStepIterRangeVulnerability) { CheckedExpr expr; // The nested comprehension unsafely modifies the parent accumulator // (outer_accu) being used as a iterable range @@ -558,6 +576,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, } )pb", &expr); + + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder; builder.flat_expr_builder().AddProgramOptimizer( CreateComprehensionVulnerabilityCheck()); @@ -567,7 +587,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, HasSubstr("memory exhaustion vulnerability"))); } -TEST(CelExpressionBuilderFlatImplComprehensionsTest, InvalidBindComprehension) { +TEST_P(CelExpressionBuilderFlatImplComprehensionsTest, + InvalidBindComprehension) { ParsedExpr expr; // Trivial comprehensions (such as cel.bind), are optimized by skipping the // planning for the loop step, however the planner will still warn if the @@ -598,7 +619,8 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, InvalidBindComprehension) { } })pb", &expr)); - cel::RuntimeOptions options; + + cel::RuntimeOptions options = GetRuntimeOptions(); CelExpressionBuilderFlatImpl builder(options); ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); @@ -609,6 +631,13 @@ TEST(CelExpressionBuilderFlatImplComprehensionsTest, InvalidBindComprehension) { HasSubstr("Unexpected iter_var access in trivial comprehension"))); } +INSTANTIATE_TEST_SUITE_P(TestSuite, + CelExpressionBuilderFlatImplComprehensionsTest, + testing::Bool(), + [](const testing::TestParamInfo& info) { + return info.param ? "recursive" : "default"; + }); + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_extensions.cc b/eval/compiler/flat_expr_builder_extensions.cc index add19c1b9..11e132321 100644 --- a/eval/compiler/flat_expr_builder_extensions.cc +++ b/eval/compiler/flat_expr_builder_extensions.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "eval/compiler/flat_expr_builder_extensions.h" +#include #include #include #include @@ -25,8 +26,10 @@ #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { @@ -40,6 +43,8 @@ Subexpression::Subexpression(const cel::ast_internal::Expr* self, size_t Subexpression::ComputeSize() const { if (IsFlattened()) { return flattened_elements().size(); + } else if (IsRecursive()) { + return 1; } std::vector to_expand{this}; size_t size = 0; @@ -49,6 +54,9 @@ size_t Subexpression::ComputeSize() const { if (expr->IsFlattened()) { size += expr->flattened_elements().size(); continue; + } else if (expr->IsRecursive()) { + size += 1; + continue; } for (const auto& elem : expr->elements()) { if (auto* child = absl::get_if>(&elem); @@ -62,6 +70,47 @@ size_t Subexpression::ComputeSize() const { return size; } +absl::optional Subexpression::RecursiveDependencyDepth() const { + auto* tree = absl::get_if(&program_); + int depth = 0; + if (tree == nullptr) { + return absl::nullopt; + } + for (const auto& element : *tree) { + auto* subexpression = + absl::get_if>(&element); + if (subexpression == nullptr) { + return absl::nullopt; + } + if (!(*subexpression)->IsRecursive()) { + return absl::nullopt; + } + depth = std::max(depth, (*subexpression)->recursive_program().depth); + } + return depth; +} + +std::vector> +Subexpression::ExtractRecursiveDependencies() const { + auto* tree = absl::get_if(&program_); + std::vector> dependencies; + if (tree == nullptr) { + return {}; + } + for (const auto& element : *tree) { + auto* subexpression = + absl::get_if>(&element); + if (subexpression == nullptr) { + return {}; + } + if (!(*subexpression)->IsRecursive()) { + return {}; + } + dependencies.push_back((*subexpression)->ExtractRecursiveProgram().step); + } + return dependencies; +} + Subexpression::~Subexpression() { auto map_ptr = subprogram_map_.lock(); if (map_ptr == nullptr) { @@ -97,6 +146,7 @@ std::unique_ptr Subexpression::ExtractChild( int Subexpression::CalculateOffset(int base, int target) const { ABSL_DCHECK(!IsFlattened()); + ABSL_DCHECK(!IsRecursive()); ABSL_DCHECK_GE(base, 0); ABSL_DCHECK_GE(target, 0); ABSL_DCHECK_LE(base, elements().size()); @@ -146,10 +196,14 @@ void Subexpression::Flatten() { Record top = flatten_stack.back(); flatten_stack.pop_back(); size_t offset = top.offset; - auto& subexpr = top.subexpr; + auto* subexpr = top.subexpr; if (subexpr->IsFlattened()) { absl::c_move(subexpr->flattened_elements(), std::back_inserter(flat)); continue; + } else if (subexpr->IsRecursive()) { + flat.push_back(std::make_unique( + std::move(subexpr->ExtractRecursiveProgram().step), + subexpr->self_->id())); } size_t size = subexpr->elements().size(); size_t i = offset; @@ -160,9 +214,10 @@ void Subexpression::Flatten() { flatten_stack.push_back({subexpr, i + 1}); flatten_stack.push_back({child->get(), 0}); break; - } else { - flat.push_back( - absl::get>(std::move(element))); + } else if (auto* step = + absl::get_if>(&element); + step != nullptr) { + flat.push_back(std::move(*step)); } } if (i >= size && subexpr != this) { @@ -173,6 +228,13 @@ void Subexpression::Flatten() { program_ = std::move(flat); } +Subexpression::RecursiveProgram Subexpression::ExtractRecursiveProgram() { + ABSL_DCHECK(IsRecursive()); + auto result = std::move(absl::get(program_)); + program_.emplace>(); + return result; +} + bool Subexpression::ExtractTo( std::vector>& out) { if (!IsFlattened()) { @@ -343,6 +405,19 @@ absl::Status PlannerContext::ReplaceSubplan(const cel::ast_internal::Expr& node, return absl::OkStatus(); } +absl::Status PlannerContext::ReplaceSubplan( + const cel::ast_internal::Expr& node, + std::unique_ptr step, int depth) { + auto* subexpression = program_builder_.GetSubexpression(&node); + if (subexpression == nullptr) { + return absl::InternalError( + "attempted to update program step for untracked expr node"); + } + + subexpression->set_recursive_program(std::move(step), depth); + return absl::OkStatus(); +} + absl::Status PlannerContext::AddSubplanStep( const cel::ast_internal::Expr& node, std::unique_ptr step) { auto* subexpression = program_builder_.GetSubexpression(&node); diff --git a/eval/compiler/flat_expr_builder_extensions.h b/eval/compiler/flat_expr_builder_extensions.h index 23ea49a5a..10f5513ce 100644 --- a/eval/compiler/flat_expr_builder_extensions.h +++ b/eval/compiler/flat_expr_builder_extensions.h @@ -33,13 +33,19 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/optional.h" #include "absl/types/variant.h" #include "base/ast.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" +#include "common/native_type.h" +#include "common/value.h" #include "common/value_manager.h" #include "eval/compiler/resolver.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/trace_step.h" +#include "internal/casts.h" #include "runtime/internal/issue_collector.h" #include "runtime/runtime_options.h" @@ -76,7 +82,15 @@ class ProgramBuilder { using Element = absl::variant, std::unique_ptr>; + using TreePlan = std::vector; + using FlattenedPlan = std::vector>; + public: + struct RecursiveProgram { + std::unique_ptr step; + int depth; + }; + ~Subexpression(); // Not copyable or movable. @@ -86,17 +100,24 @@ class ProgramBuilder { Subexpression& operator=(Subexpression&&) = delete; // Add a program step at the current end of the subexpression. - void AddStep(std::unique_ptr step) { + bool AddStep(std::unique_ptr step) { + if (IsRecursive()) { + return false; + } + if (IsFlattened()) { flattened_elements().push_back(std::move(step)); - } else { - elements().push_back(std::move(step)); + return true; } + + elements().push_back({std::move(step)}); + return true; } void AddSubexpression(std::unique_ptr expr) { ABSL_DCHECK(!IsFlattened()); - elements().push_back(std::move(expr)); + ABSL_DCHECK(!IsRecursive()); + elements().push_back({std::move(expr)}); } // Accessor for elements (either simple steps or subexpressions). @@ -104,12 +125,12 @@ class ProgramBuilder { // Value is undefined if in the expression has already been flattened. std::vector& elements() { ABSL_DCHECK(!IsFlattened()); - return absl::get>(program_); + return absl::get(program_); } const std::vector& elements() const { ABSL_DCHECK(!IsFlattened()); - return absl::get>(program_); + return absl::get(program_); } // Accessor for program steps. @@ -117,15 +138,34 @@ class ProgramBuilder { // Value is undefined if in the expression has not yet been flattened. std::vector>& flattened_elements() { ABSL_DCHECK(IsFlattened()); - return absl::get>>( - program_); + return absl::get(program_); } const std::vector>& flattened_elements() const { ABSL_DCHECK(IsFlattened()); - return absl::get>>( - program_); + return absl::get(program_); + } + + void set_recursive_program(std::unique_ptr step, + int depth) { + program_ = RecursiveProgram{std::move(step), depth}; + } + + const RecursiveProgram& recursive_program() const { + ABSL_DCHECK(IsRecursive()); + return absl::get(program_); + } + + absl::optional RecursiveDependencyDepth() const; + + std::vector> + ExtractRecursiveDependencies() const; + + RecursiveProgram ExtractRecursiveProgram(); + + bool IsRecursive() const { + return absl::holds_alternative(program_); } // Compute the current number of program steps in this subexpression and @@ -150,8 +190,7 @@ class ProgramBuilder { void Flatten(); bool IsFlattened() const { - return absl::holds_alternative< - std::vector>>(program_); + return absl::holds_alternative(program_); } // Extract a flattened subexpression into the given vector. Transferring @@ -169,9 +208,7 @@ class ProgramBuilder { // // This adds complexity, but supports swapping to a flat representation as // needed. - absl::variant, - std::vector>> - program_; + absl::variant program_; const cel::ast_internal::Expr* self_; absl::Nullable parent_; @@ -224,6 +261,17 @@ class ProgramBuilder { absl::Nullable GetSubexpression( const cel::ast_internal::Expr* expr); + // Return the extracted subexpression mapped to the given index. + // + // Returns nullptr if the mapping doesn't exist + absl::Nullable GetExtractedSubexpression(size_t index) { + if (index >= extracted_subexpressions_.size()) { + return nullptr; + } + + return extracted_subexpressions_[index].get(); + } + // Return index to the extracted subexpression. // // Returns -1 if the subexpression is not found. @@ -245,6 +293,31 @@ class ProgramBuilder { std::shared_ptr subprogram_map_; }; +// Attempt to downcast a specific type of recursive step. +template +const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) { + if (step == nullptr) { + return nullptr; + } + + auto type_id = step->GetNativeTypeId(); + if (type_id == cel::NativeTypeId::For()) { + const auto* trace_step = cel::internal::down_cast(step); + auto deps = trace_step->GetDependencies(); + if (!deps.has_value() || deps->size() != 1) { + return nullptr; + } + step = deps->at(0); + type_id = step->GetNativeTypeId(); + } + + if (type_id == cel::NativeTypeId::For()) { + return cel::internal::down_cast(step); + } + + return nullptr; +} + // Class representing FlatExpr internals exposed to extensions. class PlannerContext { public: @@ -259,6 +332,8 @@ class PlannerContext { issue_collector_(issue_collector), program_builder_(program_builder) {} + ProgramBuilder& program_builder() { return program_builder_; } + // Returns true if the subplan is inspectable. // // If false, the node is not mapped to a subexpression in the program builder. @@ -286,6 +361,14 @@ class PlannerContext { absl::Status ReplaceSubplan(const cel::ast_internal::Expr& node, ExecutionPath path); + // Replace the subplan associated with node with a new recursive subplan. + // + // This operation clears any existing plan to which removes the + // expr->program mapping for any descendants. + absl::Status ReplaceSubplan(const cel::ast_internal::Expr& node, + std::unique_ptr step, + int depth); + // Extend the current subplan with the given expression step. absl::Status AddSubplanStep(const cel::ast_internal::Expr& node, std::unique_ptr step); diff --git a/eval/compiler/flat_expr_builder_extensions_test.cc b/eval/compiler/flat_expr_builder_extensions_test.cc index 8f0bb27cf..2d84add88 100644 --- a/eval/compiler/flat_expr_builder_extensions_test.cc +++ b/eval/compiler/flat_expr_builder_extensions_test.cc @@ -19,11 +19,14 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" #include "common/memory.h" +#include "common/native_type.h" #include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/compiler/resolver.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/function_step.h" #include "internal/status_macros.h" #include "internal/testing.h" #include "runtime/function_registry.h" @@ -40,8 +43,11 @@ using ::cel::ast_internal::Expr; using ::cel::runtime_internal::IssueCollector; using testing::ElementsAre; using testing::IsEmpty; +using testing::Optional; using cel::internal::StatusIs; +using Subexpression = ProgramBuilder::Subexpression; + class PlannerContextTest : public testing::Test { public: PlannerContextTest() @@ -468,5 +474,54 @@ TEST_F(ProgramBuilderTest, ExtractToRequiresFlatten) { UniquePtrHolds(step_ptrs.a))); } +TEST_F(ProgramBuilderTest, Recursive) { + Expr a; + Expr b; + Expr c; + + ProgramBuilder program_builder; + + program_builder.EnterSubexpression(&a); + program_builder.EnterSubexpression(&b); + program_builder.current()->set_recursive_program( + CreateConstValueDirectStep(value_factory_.GetNullValue()), 1); + program_builder.ExitSubexpression(&b); + program_builder.EnterSubexpression(&c); + program_builder.current()->set_recursive_program( + CreateConstValueDirectStep(value_factory_.GetNullValue()), 1); + program_builder.ExitSubexpression(&c); + + ASSERT_FALSE(program_builder.current()->IsFlattened()); + ASSERT_FALSE(program_builder.current()->IsRecursive()); + ASSERT_TRUE(program_builder.GetSubexpression(&b)->IsRecursive()); + ASSERT_TRUE(program_builder.GetSubexpression(&c)->IsRecursive()); + + EXPECT_EQ(program_builder.GetSubexpression(&b)->recursive_program().depth, 1); + EXPECT_EQ(program_builder.GetSubexpression(&c)->recursive_program().depth, 1); + + cel::ast_internal::Call call_expr; + call_expr.set_function("_==_"); + call_expr.mutable_args().emplace_back(); + call_expr.mutable_args().emplace_back(); + + auto max_depth = program_builder.current()->RecursiveDependencyDepth(); + + EXPECT_THAT(max_depth, Optional(1)); + + auto deps = program_builder.current()->ExtractRecursiveDependencies(); + + program_builder.current()->set_recursive_program( + CreateDirectFunctionStep(-1, call_expr, std::move(deps), {}), + *max_depth + 1); + + program_builder.ExitSubexpression(&a); + + auto path = program_builder.FlattenMain(); + + ASSERT_THAT(path, testing::SizeIs(1)); + EXPECT_TRUE(path[0]->GetNativeTypeId() == + cel::NativeTypeId::For()); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index ca0445a02..dcdd273bd 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -1343,7 +1343,7 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { Expr expr; SourceInfo source_info; // [1, 2].all(x, x > 0) - google::protobuf::TextFormat::ParseFromString(R"( + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(R"( comprehension_expr { iter_var: "k" accu_var: "accu" @@ -1369,12 +1369,12 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { } iter_range { list_expr { - { const_expr { int64_value: 1 } } - { const_expr { int64_value: 2 } } + elements { const_expr { int64_value: 1 } } + elements { const_expr { int64_value: 2 } } } } })", - &expr); + &expr)); cel::RuntimeOptions options; options.comprehension_max_iterations = 1; diff --git a/eval/compiler/regex_precompilation_optimization.cc b/eval/compiler/regex_precompilation_optimization.cc index 60ba7ad1d..b34ede0e7 100644 --- a/eval/compiler/regex_precompilation_optimization.cc +++ b/eval/compiler/regex_precompilation_optimization.cc @@ -19,7 +19,9 @@ #include #include #include +#include +#include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" @@ -31,20 +33,26 @@ #include "common/value.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/compiler_constant_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/regex_match_step.h" #include "internal/casts.h" #include "internal/status_macros.h" +#include "re2/re2.h" namespace google::api::expr::runtime { namespace { -using cel::NativeTypeId; -using cel::ast_internal::AstImpl; -using cel::ast_internal::Call; -using cel::ast_internal::Expr; -using cel::ast_internal::Reference; -using cel::internal::down_cast; +using ::cel::Cast; +using ::cel::InstanceOf; +using ::cel::NativeTypeId; +using ::cel::StringValue; +using ::cel::Value; +using ::cel::ast_internal::AstImpl; +using ::cel::ast_internal::Call; +using ::cel::ast_internal::Expr; +using ::cel::ast_internal::Reference; +using ::cel::internal::down_cast; using ReferenceMap = absl::flat_hash_map; @@ -131,54 +139,121 @@ class RegexPrecompilationOptimization : public ProgramOptimizer { return absl::OkStatus(); } + ProgramBuilder::Subexpression* subexpression = + context.program_builder().GetSubexpression(&node); + if (subexpression == nullptr || subexpression->IsFlattened()) { + // Already modified, can't update further. + return absl::OkStatus(); + } + const Call& call_expr = node.call_expr(); const Expr& pattern_expr = call_expr.args().back(); absl::optional pattern = - GetConstantString(context, pattern_expr); + GetConstantString(context, subexpression, node, pattern_expr); if (!pattern.has_value()) { return absl::OkStatus(); } - CEL_ASSIGN_OR_RETURN(auto program, regex_program_builder_.BuildRegexProgram( - std::move(pattern).value())); + CEL_ASSIGN_OR_RETURN( + std::shared_ptr regex_program, + regex_program_builder_.BuildRegexProgram(std::move(pattern).value())); const Expr& subject_expr = call_expr.has_target() ? call_expr.target() : call_expr.args().front(); - if (context.GetSubplan(subject_expr).empty()) { - // This subexpression was already optimized, nothing to do. - return absl::OkStatus(); - } - - CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan, - context.ExtractSubplan(subject_expr)); - CEL_ASSIGN_OR_RETURN(new_plan.emplace_back(), - CreateRegexMatchStep(std::move(program), node.id())); - - return context.ReplaceSubplan(node, std::move(new_plan)); + return RewritePlan(context, subexpression, node, subject_expr, + std::move(regex_program)); } private: absl::optional GetConstantString( - PlannerContext& context, const cel::ast_internal::Expr& expr) const { - if (expr.has_const_expr() && expr.const_expr().has_string_value()) { - return expr.const_expr().string_value(); + PlannerContext& context, + absl::Nonnull subexpression, + const cel::ast_internal::Expr& call_expr, + const cel::ast_internal::Expr& re_expr) const { + if (re_expr.has_const_expr() && re_expr.const_expr().has_string_value()) { + return re_expr.const_expr().string_value(); } - ExecutionPathView re_plan = context.GetSubplan(expr); - if (re_plan.size() == 1 && re_plan[0]->GetNativeTypeId() == - NativeTypeId::For()) { - const auto& constant = - down_cast(*re_plan[0]); - if (constant.value()->Is()) { - return constant.value()->As().ToString(); + absl::optional constant; + if (subexpression->IsRecursive()) { + const auto& program = subexpression->recursive_program(); + auto deps = program.step->GetDependencies(); + if (deps.has_value() && deps->size() == 2) { + const auto* re_plan = + TryDowncastDirectStep(deps->at(1)); + if (re_plan != nullptr) { + constant = re_plan->value(); + } + } + } else { + // otherwise stack-machine program. + ExecutionPathView re_plan = context.GetSubplan(re_expr); + if (re_plan.size() == 1 && + re_plan[0]->GetNativeTypeId() == + NativeTypeId::For()) { + constant = + down_cast(re_plan[0].get())->value(); } } + if (constant.has_value() && InstanceOf(*constant)) { + return Cast(*constant).ToString(); + } + return absl::nullopt; } + absl::Status RewritePlan( + PlannerContext& context, + absl::Nonnull subexpression, + const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + if (subexpression->IsRecursive()) { + return RewriteRecursivePlan(subexpression, call, subject, + std::move(regex_program)); + } + return RewriteStackMachinePlan(context, call, subject, + std::move(regex_program)); + } + + absl::Status RewriteRecursivePlan( + absl::Nonnull subexpression, + const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + auto program = subexpression->ExtractRecursiveProgram(); + auto deps = program.step->ExtractDependencies(); + if (!deps.has_value() || deps->size() != 2) { + // Possibly already const-folded, put the plan back. + subexpression->set_recursive_program(std::move(program.step), + program.depth); + return absl::OkStatus(); + } + subexpression->set_recursive_program( + CreateDirectRegexMatchStep(call.id(), std::move(deps->at(0)), + std::move(regex_program)), + program.depth); + return absl::OkStatus(); + } + + absl::Status RewriteStackMachinePlan( + PlannerContext& context, const Expr& call, const Expr& subject, + std::shared_ptr regex_program) { + if (context.GetSubplan(subject).empty()) { + // This subexpression was already optimized, nothing to do. + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(ExecutionPath new_plan, + context.ExtractSubplan(subject)); + CEL_ASSIGN_OR_RETURN( + new_plan.emplace_back(), + CreateRegexMatchStep(std::move(regex_program), call.id())); + + return context.ReplaceSubplan(call, std::move(new_plan)); + } + const ReferenceMap& reference_map_; RegexProgramBuilder regex_program_builder_; }; diff --git a/eval/compiler/regex_precompilation_optimization_test.cc b/eval/compiler/regex_precompilation_optimization_test.cc index 90f8f3e51..6e2b02031 100644 --- a/eval/compiler/regex_precompilation_optimization_test.cc +++ b/eval/compiler/regex_precompilation_optimization_test.cc @@ -14,26 +14,29 @@ #include "eval/compiler/regex_precompilation_optimization.h" +#include #include +#include #include +#include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/status/status.h" #include "base/ast_internal/ast_impl.h" #include "base/ast_internal/expr.h" #include "common/memory.h" -#include "common/type_factory.h" -#include "common/type_manager.h" -#include "common/value_manager.h" #include "common/values/legacy_value_manager.h" #include "eval/compiler/cel_expression_builder_flat_impl.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/flat_expr_builder.h" #include "eval/compiler/flat_expr_builder_extensions.h" -#include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" #include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_expression.h" #include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" #include "internal/testing.h" #include "parser/parser.h" #include "runtime/internal/issue_collector.h" @@ -47,10 +50,11 @@ using ::cel::RuntimeIssue; using ::cel::ast_internal::CheckedExpr; using ::cel::runtime_internal::IssueCollector; using ::google::api::expr::parser::Parse; +using testing::ElementsAre; namespace exprpb = google::api::expr::v1alpha1; -class RegexPrecompilationExtensionTest : public testing::Test { +class RegexPrecompilationExtensionTest : public testing::TestWithParam { public: RegexPrecompilationExtensionTest() : type_registry_(*builder_.GetTypeRegistry()), @@ -61,6 +65,10 @@ class RegexPrecompilationExtensionTest : public testing::Test { type_registry_.InternalGetModernRegistry(), value_factory_, type_registry_.resolveable_enums()), issue_collector_(RuntimeIssue::Severity::kError) { + if (EnableRecursivePlanning()) { + options_.max_recursion_depth = -1; + options_.enable_recursive_tracing = true; + } options_.enable_regex = true; options_.regex_max_program_size = 100; options_.enable_regex_precompilation = true; @@ -71,7 +79,18 @@ class RegexPrecompilationExtensionTest : public testing::Test { ASSERT_OK(RegisterBuiltinFunctions(&function_registry_, options_)); } + bool EnableRecursivePlanning() { return GetParam(); } + protected: + CelEvaluationListener RecordStringValues() { + return [this](int64_t, const CelValue& value, google::protobuf::Arena*) { + if (value.IsString()) { + string_values_.push_back(std::string(value.StringOrDie().value())); + } + return absl::OkStatus(); + }; + } + CelExpressionBuilderFlatImpl builder_; CelTypeRegistry& type_registry_; CelFunctionRegistry& function_registry_; @@ -80,9 +99,10 @@ class RegexPrecompilationExtensionTest : public testing::Test { cel::common_internal::LegacyValueManager value_factory_; Resolver resolver_; IssueCollector issue_collector_; + std::vector string_values_; }; -TEST_F(RegexPrecompilationExtensionTest, SmokeTest) { +TEST_P(RegexPrecompilationExtensionTest, SmokeTest) { ProgramOptimizerFactory factory = CreateRegexPrecompilationExtension(options_.regex_max_program_size); ExecutionPath path; @@ -96,20 +116,7 @@ TEST_F(RegexPrecompilationExtensionTest, SmokeTest) { factory(context, ast_impl)); } -MATCHER_P(ExpressionPlanSizeIs, size, "") { - // This is brittle, but the most direct way to test that the plan - // was optimized. - const std::unique_ptr& plan = arg; - - const CelExpressionFlatImpl* impl = - dynamic_cast(plan.get()); - - if (impl == nullptr) return false; - *result_listener << "got size " << impl->flat_expression().path().size(); - return impl->flat_expression().path().size() == size; -} - -TEST_F(RegexPrecompilationExtensionTest, OptimizeableExpression) { +TEST_P(RegexPrecompilationExtensionTest, OptimizeableExpression) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -125,10 +132,15 @@ TEST_F(RegexPrecompilationExtensionTest, OptimizeableExpression) { ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); - EXPECT_THAT(plan, ExpressionPlanSizeIs(2)); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); } -TEST_F(RegexPrecompilationExtensionTest, OptimizeParsedExpr) { +TEST_P(RegexPrecompilationExtensionTest, OptimizeParsedExpr) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -139,10 +151,15 @@ TEST_F(RegexPrecompilationExtensionTest, OptimizeParsedExpr) { std::unique_ptr plan, builder_.CreateExpression(&expr.expr(), &expr.source_info())); - EXPECT_THAT(plan, ExpressionPlanSizeIs(2)); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); } -TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { +TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -158,10 +175,16 @@ TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeNonConstRegex) { ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); - EXPECT_THAT(plan, ExpressionPlanSizeIs(3)); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + activation.InsertValue("input_re", CelValue::CreateStringView("input_re")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123", "input_re")); } -TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { +TEST_P(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -177,7 +200,12 @@ TEST_F(RegexPrecompilationExtensionTest, DoesNotOptimizeCompoundExpr) { ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); - EXPECT_THAT(plan, ExpressionPlanSizeIs(5)) << expr.DebugString(); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123", "abc", "def", "abcdef")); } class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { @@ -192,7 +220,7 @@ class RegexConstFoldInteropTest : public RegexPrecompilationExtensionTest { google::protobuf::Arena arena_; }; -TEST_F(RegexConstFoldInteropTest, StringConstantOptimizeable) { +TEST_P(RegexConstFoldInteropTest, StringConstantOptimizeable) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -207,11 +235,15 @@ TEST_F(RegexConstFoldInteropTest, StringConstantOptimizeable) { ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); - EXPECT_THAT(plan, ExpressionPlanSizeIs(2)) << expr.DebugString(); + ASSERT_OK(plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); } -TEST_F(RegexConstFoldInteropTest, WrongTypeNotOptimized) { +TEST_P(RegexConstFoldInteropTest, WrongTypeNotOptimized) { builder_.flat_expr_builder().AddProgramOptimizer( CreateRegexPrecompilationExtension(options_.regex_max_program_size)); @@ -227,8 +259,22 @@ TEST_F(RegexConstFoldInteropTest, WrongTypeNotOptimized) { ASSERT_OK_AND_ASSIGN(std::unique_ptr plan, builder_.CreateExpression(&expr)); - EXPECT_THAT(plan, ExpressionPlanSizeIs(3)) << expr.DebugString(); + Activation activation; + google::protobuf::Arena arena; + activation.InsertValue("input", CelValue::CreateStringView("input123")); + + ASSERT_OK_AND_ASSIGN(CelValue result, + plan->Trace(activation, &arena, RecordStringValues())); + EXPECT_THAT(string_values_, ElementsAre("input123")); + EXPECT_TRUE(result.IsError()); + EXPECT_TRUE(CheckNoMatchingOverloadError(result)); } +INSTANTIATE_TEST_SUITE_P(RegexPrecompilationExtensionTest, + RegexPrecompilationExtensionTest, testing::Bool()); + +INSTANTIATE_TEST_SUITE_P(RegexConstFoldInteropTest, RegexConstFoldInteropTest, + testing::Bool()); + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 36b6e0e35..281b3f24d 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -20,6 +20,15 @@ licenses(["notice"]) exports_files(["LICENSE"]) +package_group( + name = "internal_eval_visibility", + packages = [ + "//eval/...", + "//extensions", + "//runtime/internal", + ], +) + cc_library( name = "evaluator_core", srcs = [ @@ -37,12 +46,12 @@ cc_library( "//common:native_type", "//common:type", "//common:value", - "//internal:status_macros", "//runtime", "//runtime:activation_interface", "//runtime:managed_value_factory", "//runtime:runtime_options", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", @@ -63,14 +72,25 @@ cc_library( "cel_expression_flat_impl.h", ], deps = [ + ":attribute_trail", + ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", + "//common:native_type", "//common:value", "//eval/internal:adapter_activation_impl", "//eval/internal:interop", + "//eval/public:base_activation", "//eval/public:cel_expression", "//eval/public:cel_value", "//extensions/protobuf:memory_manager", + "//internal:casts", + "//internal:status_macros", + "//runtime:managed_value_factory", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], ) @@ -84,6 +104,7 @@ cc_library( ":attribute_trail", "//common:value", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/types:optional", ], ) @@ -155,12 +176,15 @@ cc_library( "const_value_step.h", ], deps = [ + ":attribute_trail", ":compiler_constant_step", + ":direct_expression_step", ":evaluator_core", "//base/ast_internal:expr", "//common:value", "//internal:status_macros", "//runtime/internal:convert_constant", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) @@ -174,6 +198,9 @@ cc_library( "container_access_step.h", ], deps = [ + ":attribute_trail", + ":attribute_utility", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:attributes", @@ -187,6 +214,7 @@ cc_library( "//internal:casts", "//internal:number", "//internal:status_macros", + "//runtime/internal:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -200,11 +228,17 @@ cc_library( srcs = ["regex_match_step.cc"], hdrs = ["regex_match_step.h"], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", + "//common:casting", "//common:value", + "//internal:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", "@com_googlesource_code_re2//:re2", ], ) @@ -220,9 +254,11 @@ cc_library( deps = [ ":attribute_trail", ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base/ast_internal:expr", + "//common:value", "//eval/internal:errors", "//internal:status_macros", "@com_google_absl//absl/status", @@ -241,12 +277,14 @@ cc_library( ], deps = [ ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:function", "//base:function_descriptor", "//base:kind", "//base/ast_internal:expr", + "//common:casting", "//common:value", "//eval/internal:errors", "//internal:status_macros", @@ -254,6 +292,7 @@ cc_library( "//runtime:function_overload_reference", "//runtime:function_provider", "//runtime:function_registry", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -271,6 +310,8 @@ cc_library( "select_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:kind", @@ -279,9 +320,11 @@ cc_library( "//common:native_type", "//common:value", "//eval/internal:errors", + "//eval/public:ast_visitor", "//internal:casts", "//internal:status_macros", "//runtime:runtime_options", + "@com_google_absl//absl/log:absl_log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -298,6 +341,9 @@ cc_library( "create_list_step.h", ], deps = [ + ":attribute_trail", + ":attribute_utility", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base/ast_internal:expr", @@ -321,12 +367,13 @@ cc_library( "create_struct_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base/ast_internal:expr", "//common:casting", "//common:memory", - "//common:type", "//common:value", "//internal:status_macros", "@com_google_absl//absl/container:flat_hash_set", @@ -346,6 +393,8 @@ cc_library( "create_map_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base/ast_internal:expr", @@ -393,13 +442,20 @@ cc_library( "logic_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:builtins", + "//common:casting", "//common:value", + "//common:value_kind", "//eval/internal:errors", + "//internal:status_macros", + "//runtime/internal:errors", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -415,16 +471,23 @@ cc_library( deps = [ ":attribute_trail", ":comprehension_slots", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:attributes", "//base:kind", "//common:casting", "//common:value", + "//common:value_kind", "//eval/internal:errors", + "//eval/public:cel_attribute", "//internal:status_macros", "//runtime/internal:mutable_list_impl", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", @@ -439,15 +502,26 @@ cc_test( ], deps = [ ":cel_expression_flat_impl", + ":comprehension_slots", ":comprehension_step", + ":const_value_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:data", + "//base/ast_internal:expr", + "//common:value", + "//common:value_testing", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", + "//extensions/protobuf:memory_manager", + "//internal:status_macros", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", + "//runtime:runtime_options", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -516,6 +590,7 @@ cc_test( deps = [ ":cel_expression_flat_impl", ":container_access_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:builtins", @@ -526,6 +601,7 @@ cc_test( "//eval/public:cel_expression", "//eval/public:cel_options", "//eval/public:cel_value", + "//eval/public:unknown_set", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", @@ -567,13 +643,20 @@ cc_test( "ident_step_test.cc", ], deps = [ + ":attribute_trail", ":cel_expression_flat_impl", ":evaluator_core", ":ident_step", "//base:data", + "//common:memory", + "//common:value", "//eval/public:activation", + "//eval/public:cel_attribute", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", "//runtime:runtime_options", + "@com_google_absl//absl/status", ], ) @@ -586,11 +669,14 @@ cc_test( deps = [ ":cel_expression_flat_impl", ":const_value_step", + ":direct_expression_step", ":evaluator_core", ":function_step", ":ident_step", + "//base:builtins", "//base:data", "//base/ast_internal:expr", + "//common:kind", "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", @@ -602,10 +688,17 @@ cc_test( "//eval/public/structs:cel_proto_wrapper", "//eval/public/testing:matchers", "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//runtime", + "//runtime:function_overload_reference", + "//runtime:function_registry", + "//runtime:managed_value_factory", "//runtime:runtime_options", + "//runtime:standard_functions", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -616,17 +709,33 @@ cc_test( "logic_step_test.cc", ], deps = [ + ":attribute_trail", ":cel_expression_flat_impl", + ":const_value_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", ":logic_step", + "//base:attributes", "//base:data", + "//base/ast_internal:expr", + "//common:casting", + "//common:value", "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", "//runtime:runtime_options", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_protobuf//:protobuf", ], ) @@ -637,13 +746,18 @@ cc_test( "select_step_test.cc", ], deps = [ + ":attribute_trail", ":cel_expression_flat_impl", + ":const_value_step", ":evaluator_core", ":ident_step", ":select_step", + "//base:attributes", "//base:data", - "//common:type", + "//base/ast_internal:expr", + "//common:legacy_value", "//common:value", + "//common:value_testing", "//eval/public:activation", "//eval/public:cel_attribute", "//eval/public:cel_value", @@ -655,13 +769,17 @@ cc_test( "//eval/testutil:test_extensions_cc_proto", "//eval/testutil:test_message_cc_proto", "//extensions/protobuf:memory_manager", + "//internal:proto_matchers", "//internal:status_macros", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", "//runtime:runtime_options", - "//testutil:util", + "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], @@ -674,11 +792,20 @@ cc_test( "create_list_step_test.cc", ], deps = [ + ":attribute_trail", ":cel_expression_flat_impl", ":const_value_step", ":create_list_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", + "//base:attributes", + "//base:data", + "//base/ast_internal:expr", + "//common:casting", + "//common:memory", + "//common:value", + "//common:value_testing", "//eval/internal:interop", "//eval/public:activation", "//eval/public:cel_attribute", @@ -686,7 +813,10 @@ cc_test( "//eval/public/testing:matchers", "//internal:status_macros", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", "//runtime:runtime_options", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], @@ -701,6 +831,7 @@ cc_test( deps = [ ":cel_expression_flat_impl", ":create_struct_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:data", @@ -738,6 +869,7 @@ cc_test( deps = [ ":cel_expression_flat_impl", ":create_map_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", "//base:data", @@ -790,10 +922,14 @@ cc_library( ":attribute_trail", "//base:attributes", "//base:function_descriptor", + "//base:function_result", "//base:function_result_set", "//base/internal:unknown_set", "//common:value", "//eval/internal:errors", + "//internal:status_macros", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -828,11 +964,16 @@ cc_library( "ternary_step.h", ], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//base:builtins", + "//common:casting", "//common:value", "//eval/internal:errors", + "//internal:status_macros", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) @@ -844,17 +985,30 @@ cc_test( "ternary_step_test.cc", ], deps = [ + ":attribute_trail", ":cel_expression_flat_impl", + ":const_value_step", + ":direct_expression_step", ":evaluator_core", ":ident_step", ":ternary_step", + "//base:attributes", "//base:data", + "//base/ast_internal:expr", + "//common:casting", + "//common:value", "//eval/public:activation", + "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_set", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "//internal:testing", + "//runtime:activation", + "//runtime:managed_value_factory", "//runtime:runtime_options", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", ], ) @@ -864,10 +1018,14 @@ cc_library( srcs = ["shadowable_value_step.cc"], hdrs = ["shadowable_value_step.h"], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", "//common:value", "//internal:status_macros", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], ) @@ -897,8 +1055,13 @@ cc_library( srcs = ["compiler_constant_step.cc"], hdrs = ["compiler_constant_step.h"], deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", ":expression_step_base", "//common:native_type", + "//common:value", + "@com_google_absl//absl/status", ], ) @@ -925,9 +1088,15 @@ cc_library( srcs = ["lazy_init_step.cc"], hdrs = ["lazy_init_step.h"], deps = [ + ":attribute_trail", + ":direct_expression_step", ":evaluator_core", ":expression_step_base", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", + "@com_google_googleapis//google/api/expr/v1alpha1:value_cc_proto", ], ) @@ -949,3 +1118,34 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "direct_expression_step", + srcs = ["direct_expression_step.cc"], + hdrs = ["direct_expression_step.h"], + deps = [ + ":attribute_trail", + ":evaluator_core", + "//common:native_type", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "trace_step", + hdrs = ["trace_step.h"], + deps = [ + ":attribute_trail", + ":direct_expression_step", + ":evaluator_core", + "//common:native_type", + "//common:value", + "//internal:status_macros", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/eval/eval/attribute_utility.cc b/eval/eval/attribute_utility.cc index dcf48f806..9a67a4408 100644 --- a/eval/eval/attribute_utility.cc +++ b/eval/eval/attribute_utility.cc @@ -1,19 +1,37 @@ #include "eval/eval/attribute_utility.h" +#include +#include #include +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "base/attribute.h" #include "base/attribute_set.h" +#include "base/function_descriptor.h" +#include "base/function_result.h" +#include "base/function_result_set.h" #include "base/internal/unknown_set.h" #include "common/value.h" +#include "eval/eval/attribute_trail.h" #include "eval/internal/errors.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { using ::cel::AttributeSet; +using ::cel::Cast; using ::cel::ErrorValue; +using ::cel::FunctionResult; +using ::cel::FunctionResultSet; +using ::cel::InstanceOf; using ::cel::UnknownValue; +using ::cel::Value; using ::cel::base_internal::UnknownSet; +using Accumulator = AttributeUtility::Accumulator; + bool AttributeUtility::CheckForMissingAttribute( const AttributeTrail& trail) const { if (trail.empty()) { @@ -79,14 +97,29 @@ absl::optional AttributeUtility::MergeUnknowns( result_set->unknown_attributes(), result_set->unknown_function_results()); } +UnknownValue AttributeUtility::MergeUnknownValues( + const UnknownValue& left, const UnknownValue& right) const { + // Empty unknown value may be used as a sentinel in some tests so need to + // distinguish unset (nullopt) and empty(engaged empty value). + AttributeSet attributes; + FunctionResultSet function_results; + attributes.Add(left.attribute_set()); + function_results.Add(left.function_result_set()); + attributes.Add(right.attribute_set()); + function_results.Add(right.function_result_set()); + + return value_factory_.CreateUnknownValue(std::move(attributes), + std::move(function_results)); +} + // Creates merged UnknownAttributeSet. // Scans over the args collection, determines if there matches to unknown // patterns, merges attributes together with those from initial_set // (if initial_set is not null). // Returns pointer to merged set or nullptr, if there were no sets to merge. -cel::AttributeSet AttributeUtility::CheckForUnknowns( +AttributeSet AttributeUtility::CheckForUnknowns( absl::Span args, bool use_partial) const { - cel::AttributeSet attribute_set; + AttributeSet attribute_set; for (const auto& trail : args) { if (CheckForUnknown(trail, use_partial)) { @@ -148,7 +181,39 @@ UnknownValue AttributeUtility::CreateUnknownSet( const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, absl::Span args) const { return value_factory_.CreateUnknownValue( - cel::FunctionResultSet(cel::FunctionResult(fn_descriptor, expr_id))); + FunctionResultSet(FunctionResult(fn_descriptor, expr_id))); +} + +void AttributeUtility::Add(Accumulator& a, const cel::UnknownValue& v) const { + a.attribute_set_.Add(v.attribute_set()); + a.function_result_set_.Add(v.function_result_set()); +} + +void AttributeUtility::Add(Accumulator& a, const AttributeTrail& attr) const { + a.attribute_set_.Add(attr.attribute()); +} + +void Accumulator::Add(const UnknownValue& value) { + unknown_present_ = true; + parent_.Add(*this, value); +} + +void Accumulator::Add(const AttributeTrail& attr) { parent_.Add(*this, attr); } + +void Accumulator::MaybeAdd(const Value& v) { + if (InstanceOf(v)) { + Add(Cast(v)); + } +} + +bool Accumulator::IsEmpty() const { + return !unknown_present_ && attribute_set_.empty() && + function_result_set_.empty(); +} + +cel::UnknownValue Accumulator::Build() && { + return parent_.value_manager().CreateUnknownValue( + std::move(attribute_set_), std::move(function_result_set_)); } } // namespace google::api::expr::runtime diff --git a/eval/eval/attribute_utility.h b/eval/eval/attribute_utility.h index 05d15df1d..aeb2d9b12 100644 --- a/eval/eval/attribute_utility.h +++ b/eval/eval/attribute_utility.h @@ -1,7 +1,9 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ +#include "absl/status/statusor.h" #include "absl/types/span.h" +#include "base/attribute.h" #include "base/attribute_set.h" #include "base/function_descriptor.h" #include "base/function_result_set.h" @@ -18,6 +20,39 @@ namespace google::api::expr::runtime { // Neither moveable nor copyable. class AttributeUtility { public: + class Accumulator { + public: + Accumulator(const Accumulator&) = delete; + Accumulator& operator=(const Accumulator&) = delete; + Accumulator(Accumulator&&) = delete; + Accumulator& operator=(Accumulator&&) = delete; + + // Add to the accumulated unknown attributes and functions. + void Add(const cel::UnknownValue& v); + void Add(const AttributeTrail& attr); + + // Add to the accumulated set of unknowns if value is UnknownValue. + void MaybeAdd(const cel::Value& v); + + bool IsEmpty() const; + + cel::UnknownValue Build() &&; + + private: + explicit Accumulator(const AttributeUtility& parent) + : parent_(parent), unknown_present_(false) {} + + friend class AttributeUtility; + const AttributeUtility& parent_; + + cel::AttributeSet attribute_set_; + cel::FunctionResultSet function_result_set_; + + // Some tests will use an empty unknown set as a sentinel. + // Preserve forwarding behavior. + bool unknown_present_; + }; + AttributeUtility( absl::Span unknown_patterns, absl::Span missing_attribute_patterns, @@ -35,9 +70,21 @@ class AttributeUtility { // attribute. bool CheckForMissingAttribute(const AttributeTrail& trail) const; - // Checks whether particular corresponds to any patterns that define unknowns. + // Checks whether trail corresponds to any patterns that define unknowns. bool CheckForUnknown(const AttributeTrail& trail, bool use_partial) const; + // Checks whether trail corresponds to any patterns that identify + // unknowns. Only matches exactly (exact attribute match for self or parent). + bool CheckForUnknownExact(const AttributeTrail& trail) const { + return CheckForUnknown(trail, false); + } + + // Checks whether trail corresponds to any patterns that define unknowns. + // Matches if a parent or any descendant (select or index of) the attribute. + bool CheckForUnknownPartial(const AttributeTrail& trail) const { + return CheckForUnknown(trail, true); + } + // Creates merged UnknownAttributeSet. // Scans over the args collection, determines if there matches to unknown // patterns and returns the (possibly empty) collection. @@ -50,6 +97,10 @@ class AttributeUtility { absl::optional MergeUnknowns( absl::Span args) const; + // Creates a merged UnknownValue from two unknown values. + cel::UnknownValue MergeUnknownValues(const cel::UnknownValue& left, + const cel::UnknownValue& right) const; + // Creates merged UnknownValue. // Merges together UnknownValues found in the args // along with attributes from attr that match the configured unknown patterns @@ -70,7 +121,17 @@ class AttributeUtility { const cel::FunctionDescriptor& fn_descriptor, int64_t expr_id, absl::Span args) const; + Accumulator CreateAccumulator() const ABSL_ATTRIBUTE_LIFETIME_BOUND { + return Accumulator(*this); + } + private: + cel::ValueManager& value_manager() const { return value_factory_; } + + // Workaround friend visibility. + void Add(Accumulator& a, const cel::UnknownValue& v) const; + void Add(Accumulator& a, const AttributeTrail& attr) const; + absl::Span unknown_patterns_; absl::Span missing_attribute_patterns_; cel::ValueManager& value_factory_; diff --git a/eval/eval/cel_expression_flat_impl.cc b/eval/eval/cel_expression_flat_impl.cc index 861fa0715..99e4ab488 100644 --- a/eval/eval/cel_expression_flat_impl.cc +++ b/eval/eval/cel_expression_flat_impl.cc @@ -16,16 +16,28 @@ #include #include +#include +#include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "common/native_type.h" #include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/internal/adapter_activation_impl.h" #include "eval/internal/interop.h" +#include "eval/public/base_activation.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "extensions/protobuf/memory_manager.h" +#include "internal/casts.h" +#include "internal/status_macros.h" +#include "runtime/managed_value_factory.h" #include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -87,4 +99,42 @@ absl::StatusOr CelExpressionFlatImpl::Evaluate( return Trace(activation, state, CelEvaluationListener()); } +absl::StatusOr> +CelExpressionRecursiveImpl::Create(FlatExpression flat_expr) { + if (flat_expr.path().size() < 1 || + flat_expr.path().front()->GetNativeTypeId() != + cel::NativeTypeId::For()) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected a recursive program step", flat_expr.path().size())); + } + + auto* instance = new CelExpressionRecursiveImpl(std::move(flat_expr)); + + return absl::WrapUnique(instance); +} + +absl::StatusOr CelExpressionRecursiveImpl::Trace( + const BaseActivation& activation, google::protobuf::Arena* arena, + CelEvaluationListener callback) const { + cel::interop_internal::AdapterActivationImpl modern_activation(activation); + cel::ManagedValueFactory factory = flat_expression_.MakeValueFactory( + cel::extensions::ProtoMemoryManagerRef(arena)); + + ComprehensionSlots slots(flat_expression_.comprehension_slots_size()); + ExecutionFrameBase execution_frame(modern_activation, AdaptListener(callback), + flat_expression_.options(), factory.get(), + slots); + + cel::Value result; + AttributeTrail trail; + CEL_RETURN_IF_ERROR(root_->Evaluate(execution_frame, result, trail)); + + return cel::interop_internal::ModernValueToLegacyValueOrDie(arena, result); +} + +absl::StatusOr CelExpressionRecursiveImpl::Evaluate( + const BaseActivation& activation, google::protobuf::Arena* arena) const { + return Trace(activation, arena, /*callback=*/nullptr); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/cel_expression_flat_impl.h b/eval/eval/cel_expression_flat_impl.h index 801a14638..f22e4726e 100644 --- a/eval/eval/cel_expression_flat_impl.h +++ b/eval/eval/cel_expression_flat_impl.h @@ -18,9 +18,14 @@ #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/public/cel_expression.h" #include "extensions/protobuf/memory_manager.h" +#include "internal/casts.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -81,6 +86,76 @@ class CelExpressionFlatImpl : public CelExpression { FlatExpression flat_expression_; }; +// Implementation of the CelExpression that evaluates a recursive representation +// of the AST. +// +// This class adapts FlatExpression to implement the CelExpression interface. +// +// Assumes that the flat expression is wrapping a simple recursive program. +class CelExpressionRecursiveImpl : public CelExpression { + private: + class EvaluationState : public CelEvaluationState { + public: + explicit EvaluationState(google::protobuf::Arena* arena) : arena_(arena) {} + google::protobuf::Arena* arena() { return arena_; } + + private: + google::protobuf::Arena* arena_; + }; + + public: + static absl::StatusOr> Create( + FlatExpression flat_expression); + + // Move-only + CelExpressionRecursiveImpl(const CelExpressionRecursiveImpl&) = delete; + CelExpressionRecursiveImpl& operator=(const CelExpressionRecursiveImpl&) = + delete; + CelExpressionRecursiveImpl(CelExpressionRecursiveImpl&&) = default; + CelExpressionRecursiveImpl& operator=(CelExpressionRecursiveImpl&&) = default; + + // Implement CelExpression. + std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const override { + return std::make_unique(arena); + } + + absl::StatusOr Evaluate(const BaseActivation& activation, + google::protobuf::Arena* arena) const override; + + absl::StatusOr Evaluate(const BaseActivation& activation, + CelEvaluationState* state) const override { + auto* state_impl = cel::internal::down_cast(state); + return Evaluate(activation, state_impl->arena()); + } + + absl::StatusOr Trace(const BaseActivation& activation, + google::protobuf::Arena* arena, + CelEvaluationListener callback) const override; + + absl::StatusOr Trace( + const BaseActivation& activation, CelEvaluationState* state, + CelEvaluationListener callback) const override { + auto* state_impl = cel::internal::down_cast(state); + return Trace(activation, state_impl->arena(), callback); + } + + // Exposed for inspection in tests. + const FlatExpression& flat_expression() const { return flat_expression_; } + + const DirectExpressionStep* root() const { return root_; } + + private: + explicit CelExpressionRecursiveImpl(FlatExpression flat_expression) + : flat_expression_(std::move(flat_expression)), + root_(cel::internal::down_cast( + flat_expression_.path()[0].get()) + ->wrapped()) {} + + FlatExpression flat_expression_; + const DirectExpressionStep* root_; +}; + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_CEL_EXPRESSION_FLAT_IMPL_H_ diff --git a/eval/eval/compiler_constant_step.cc b/eval/eval/compiler_constant_step.cc index 9933dd06b..44a03cecd 100644 --- a/eval/eval/compiler_constant_step.cc +++ b/eval/eval/compiler_constant_step.cc @@ -13,8 +13,21 @@ // limitations under the License. #include "eval/eval/compiler_constant_step.h" +#include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" + namespace google::api::expr::runtime { +using ::cel::Value; + +absl::Status DirectCompilerConstantStep::Evaluate( + ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { + result = value_; + return absl::OkStatus(); +} + absl::Status CompilerConstantStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().Push(value_); diff --git a/eval/eval/compiler_constant_step.h b/eval/eval/compiler_constant_step.h index 189cd1904..bd514a036 100644 --- a/eval/eval/compiler_constant_step.h +++ b/eval/eval/compiler_constant_step.h @@ -14,13 +14,41 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_COMPILER_CONSTANT_STEP_H_ +#include #include +#include "absl/status/status.h" #include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" namespace google::api::expr::runtime { +// DirectExpressionStep implementation that simply assigns a constant value. +// +// Overrides NativeTypeId() allow the FlatExprBuilder and extensions to +// inspect the underlying value. +class DirectCompilerConstantStep : public DirectExpressionStep { + public: + DirectCompilerConstantStep(cel::Value value, int64_t expr_id) + : DirectExpressionStep(expr_id), value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override; + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + const cel::Value& value() const { return value_; } + + private: + cel::Value value_; +}; + // ExpressionStep implementation that simply pushes a constant value on the // stack. // diff --git a/eval/eval/compiler_constant_step_test.cc b/eval/eval/compiler_constant_step_test.cc index d602d65a4..8b733125a 100644 --- a/eval/eval/compiler_constant_step_test.cc +++ b/eval/eval/compiler_constant_step_test.cc @@ -59,7 +59,7 @@ TEST_F(CompilerConstantStepTest, Evaluate) { ExecutionFrame frame(path, empty_activation_, options_, state_); - ASSERT_OK_AND_ASSIGN(cel::Value result, frame.Evaluate(EvaluationListener())); + ASSERT_OK_AND_ASSIGN(cel::Value result, frame.Evaluate()); EXPECT_EQ(result->As().NativeValue(), 42); } diff --git a/eval/eval/comprehension_slots.h b/eval/eval/comprehension_slots.h index 244fa6e3a..70c2458a0 100644 --- a/eval/eval/comprehension_slots.h +++ b/eval/eval/comprehension_slots.h @@ -20,6 +20,7 @@ #include #include "absl/base/macros.h" +#include "absl/base/no_destructor.h" #include "absl/types/optional.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" @@ -40,6 +41,13 @@ class ComprehensionSlots { AttributeTrail attribute; }; + // Trivial instance if no slots are needed. + // Trivially thread safe since no effective state. + static ComprehensionSlots& GetEmptyInstance() { + static absl::NoDestructor instance(0); + return *instance; + } + explicit ComprehensionSlots(size_t size) : size_(size), slots_(size) {} // Move only @@ -67,6 +75,11 @@ class ComprehensionSlots { slots_[index] = absl::nullopt; } + void Set(size_t index) { + ABSL_ASSERT(index >= 0 && index < slots_.size()); + slots_[index].emplace(); + } + void Set(size_t index, cel::Value value) { Set(index, std::move(value), AttributeTrail()); } diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index f93eedde9..53b555265 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -5,31 +5,38 @@ #include #include +#include "absl/log/absl_check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "base/attribute.h" #include "base/kind.h" #include "common/casting.h" #include "common/value.h" +#include "common/value_kind.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" +#include "eval/public/cel_attribute.h" #include "internal/status_macros.h" #include "runtime/internal/mutable_list_impl.h" namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; using ::cel::Cast; using ::cel::InstanceOf; using ::cel::IntValue; using ::cel::ListValue; +using ::cel::MapValue; using ::cel::UnknownValue; using ::cel::Value; +using ::cel::ValueView; using ::cel::runtime_internal::CreateNoMatchingOverloadError; using ::cel::runtime_internal::MutableListValue; @@ -79,25 +86,27 @@ class ComprehensionInitStep : public ExpressionStepBase { absl::Status ProjectKeys(ExecutionFrame* frame) const; }; -absl::Status ComprehensionInitStep::ProjectKeys(ExecutionFrame* frame) const { +absl::StatusOr ProjectKeysImpl(ExecutionFrameBase& frame, + const MapValue& range, + const AttributeTrail& trail) { // Top of stack is map, but could be partially unknown. To tolerate cases when // keys are not set for declared unknown values, convert to an unknown set. - if (frame->enable_unknowns()) { - absl::optional unknown = - frame->attribute_utility().IdentifyAndMergeUnknowns( - frame->value_stack().GetSpan(1), - frame->value_stack().GetAttributeSpan(1), - /*use_partial=*/true); - if (unknown.has_value()) { - frame->value_stack().PopAndPush(*std::move(unknown)); - return absl::OkStatus(); + if (frame.unknown_processing_enabled()) { + if (frame.attribute_utility().CheckForUnknownPartial(trail)) { + return frame.attribute_utility().CreateUnknownSet(trail.attribute()); } } - CEL_ASSIGN_OR_RETURN(auto list_keys, - frame->value_stack().Peek().As().ListKeys( - frame->value_factory())); - frame->value_stack().PopAndPush(std::move(list_keys)); + return range.ListKeys(frame.value_manager()); +} + +absl::Status ComprehensionInitStep::ProjectKeys(ExecutionFrame* frame) const { + auto map_value = Cast(frame->value_stack().Peek()); + CEL_ASSIGN_OR_RETURN( + Value keys, + ProjectKeysImpl(*frame, map_value, frame->value_stack().PeekAttribute())); + + frame->value_stack().PopAndPush(std::move(keys)); return absl::OkStatus(); } @@ -126,6 +135,148 @@ absl::Status ComprehensionInitStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class ComprehensionDirectStep : public DirectExpressionStep { + public: + explicit ComprehensionDirectStep( + size_t iter_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id) + : DirectExpressionStep(expr_id), + iter_slot_(iter_slot), + accu_slot_(accu_slot), + range_(std::move(range)), + accu_init_(std::move(accu_init)), + loop_step_(std::move(loop_step)), + condition_(std::move(condition_step)), + result_step_(std::move(result_step)), + shortcircuiting_(shortcircuiting) {} + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override; + + private: + size_t iter_slot_; + size_t accu_slot_; + std::unique_ptr range_; + std::unique_ptr accu_init_; + std::unique_ptr loop_step_; + std::unique_ptr condition_; + std::unique_ptr result_step_; + + bool shortcircuiting_; +}; + +absl::Status ComprehensionDirectStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + cel::Value range; + AttributeTrail range_attr; + CEL_RETURN_IF_ERROR(range_->Evaluate(frame, range, range_attr)); + + if (InstanceOf(range)) { + auto map_value = Cast(range); + CEL_ASSIGN_OR_RETURN(range, ProjectKeysImpl(frame, map_value, range_attr)); + } + + switch (range.kind()) { + case cel::ValueKind::kError: + case cel::ValueKind::kUnknown: + result = range; + return absl::OkStatus(); + break; + default: + if (!InstanceOf(range)) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + return absl::OkStatus(); + } + } + + auto range_list = Cast(range); + + Value accu_init; + AttributeTrail accu_init_attr; + CEL_RETURN_IF_ERROR(accu_init_->Evaluate(frame, accu_init, accu_init_attr)); + + frame.comprehension_slots().Set(accu_slot_, std::move(accu_init), + accu_init_attr); + ComprehensionSlots::Slot* accu_slot = + frame.comprehension_slots().Get(accu_slot_); + ABSL_DCHECK(accu_slot != nullptr); + + frame.comprehension_slots().Set(iter_slot_); + ComprehensionSlots::Slot* iter_slot = + frame.comprehension_slots().Get(iter_slot_); + ABSL_DCHECK(iter_slot != nullptr); + + Value condition; + AttributeTrail condition_attr; + bool should_skip_result = false; + CEL_RETURN_IF_ERROR(range_list.ForEach( + frame.value_manager(), + [&](size_t index, ValueView v) -> absl::StatusOr { + CEL_RETURN_IF_ERROR(frame.IncrementIterations()); + // Evaluate loop condition first. + CEL_RETURN_IF_ERROR( + condition_->Evaluate(frame, condition, condition_attr)); + + if (condition.kind() == cel::ValueKind::kError || + condition.kind() == cel::ValueKind::kUnknown) { + result = std::move(condition); + should_skip_result = true; + return false; + } + if (condition.kind() != cel::ValueKind::kBool) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError("")); + should_skip_result = true; + return false; + } + if (shortcircuiting_ && !Cast(condition).NativeValue()) { + return false; + } + + iter_slot->value = v; + if (frame.unknown_processing_enabled()) { + iter_slot->attribute = + range_attr.Step(CelAttributeQualifier::OfInt(index)); + if (frame.attribute_utility().CheckForUnknownExact( + iter_slot->attribute)) { + iter_slot->value = frame.attribute_utility().CreateUnknownSet( + iter_slot->attribute.attribute()); + } + } + + CEL_RETURN_IF_ERROR(loop_step_->Evaluate(frame, accu_slot->value, + accu_slot->attribute)); + + return true; + })); + + frame.comprehension_slots().ClearSlot(iter_slot_); + // Error state is already set to the return value, just clean up. + if (should_skip_result) { + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(result_step_->Evaluate(frame, result, trail)); + if (frame.options().enable_comprehension_list_append && + MutableListValue::Is(result)) { + // We assume the list builder is 'owned' by the evaluator stack so + // destructive operation is safe here. + // + // Convert the buildable list to an actual cel::ListValue. + MutableListValue& list_value = MutableListValue::Cast(result); + CEL_ASSIGN_OR_RETURN(result, std::move(list_value).Build()); + } + frame.comprehension_slots().ClearSlot(accu_slot_); + return absl::OkStatus(); +} + } // namespace // Stack variables during comprehension evaluation: @@ -307,6 +458,20 @@ absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +std::unique_ptr CreateDirectComprehensionStep( + size_t iter_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id) { + return std::make_unique( + iter_slot, accu_slot, std::move(range), std::move(accu_init), + std::move(loop_step), std::move(condition_step), std::move(result_step), + shortcircuiting, expr_id); +} + std::unique_ptr CreateComprehensionFinishStep(size_t accu_slot, int64_t expr_id) { return std::make_unique(accu_slot, expr_id); diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index a95e271b3..c0fc78aa0 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -6,6 +6,7 @@ #include #include "absl/status/status.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" @@ -45,6 +46,16 @@ class ComprehensionCondStep : public ExpressionStepBase { bool shortcircuiting_; }; +// Creates a step for executing a comprehension. +std::unique_ptr CreateDirectComprehensionStep( + size_t iter_slot, size_t accu_slot, + std::unique_ptr range, + std::unique_ptr accu_init, + std::unique_ptr loop_step, + std::unique_ptr condition_step, + std::unique_ptr result_step, bool shortcircuiting, + int64_t expr_id); + // Creates a cleanup step for the comprehension. // Removes the comprehension context then pushes the 'result' sub expression to // the top of the stack. diff --git a/eval/eval/comprehension_step_test.cc b/eval/eval/comprehension_step_test.cc index 978450d9c..e92c5bd04 100644 --- a/eval/eval/comprehension_step_test.cc +++ b/eval/eval/comprehension_step_test.cc @@ -1,6 +1,5 @@ #include "eval/eval/comprehension_step.h" -#include #include #include #include @@ -10,27 +9,47 @@ #include "google/protobuf/struct.pb.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "base/ast_internal/expr.h" #include "base/type_provider.h" +#include "common/value.h" +#include "common/value_testing.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/comprehension_slots.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" +#include "extensions/protobuf/memory_manager.h" +#include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" +#include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; +using ::cel::IntValue; using ::cel::TypeProvider; +using ::cel::Value; using ::cel::ast_internal::Expr; using ::cel::ast_internal::Ident; +using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::test::BoolValueIs; using ::google::protobuf::ListValue; using ::google::protobuf::Struct; using ::google::protobuf::Arena; +using testing::_; using testing::Eq; +using testing::Return; using testing::SizeIs; +using cel::internal::StatusIs; Ident CreateIdent(const std::string& var) { Ident expr; @@ -230,5 +249,283 @@ TEST_F(ListKeysStepTest, UnknownSetPassedThrough) { EXPECT_THAT(eval_result->UnknownSetOrDie()->unknown_attributes(), SizeIs(1)); } +class MockDirectStep : public DirectExpressionStep { + public: + MockDirectStep() : DirectExpressionStep(-1) {} + + MOCK_METHOD(absl::Status, Evaluate, + (ExecutionFrameBase&, Value&, AttributeTrail&), (const override)); +}; + +// Test fixture for comprehensions. +// +// Comprehensions are quite involved so tests here focus on edge cases that are +// hard to exercise normally in functional-style tests for the planner. +class DirectComprehensionTest : public testing::Test { + public: + DirectComprehensionTest() + : value_manager_(TypeProvider::Builtin(), ProtoMemoryManagerRef(&arena_)), + slots_(2) {} + + // returns a two element list for testing [1, 2]. + absl::StatusOr MakeList() { + CEL_ASSIGN_OR_RETURN(auto builder, + value_manager_.get().NewListValueBuilder( + value_manager_.get().GetDynListType())); + + CEL_RETURN_IF_ERROR(builder->Add(IntValue(1))); + CEL_RETURN_IF_ERROR(builder->Add(IntValue(2))); + return std::move(*builder).Build(); + } + + protected: + google::protobuf::Arena arena_; + cel::ManagedValueFactory value_manager_; + ComprehensionSlots slots_; + cel::Activation empty_activation_; +}; + +TEST_F(DirectComprehensionTest, PropagateRangeNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto range_step = std::make_unique(); + MockDirectStep* mock = range_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test range error"))); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/std::move(range_step), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test range error")); +} + +TEST_F(DirectComprehensionTest, PropagateAccuInitNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto accu_init = std::make_unique(); + MockDirectStep* mock = accu_init.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test accu init error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/std::move(accu_init), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test accu init error")); +} + +TEST_F(DirectComprehensionTest, PropagateLoopNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test loop error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test loop error")); +} + +TEST_F(DirectComprehensionTest, PropagateConditionNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto condition = std::make_unique(); + MockDirectStep* mock = condition.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test condition error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/std::move(condition), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test condition error")); +} + +TEST_F(DirectComprehensionTest, PropagateResultNonOkStatus) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto result_step = std::make_unique(); + MockDirectStep* mock = result_step.get(); + + ON_CALL(*mock, Evaluate(_, _, _)) + .WillByDefault(Return(absl::InternalError("test result error"))); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/std::move(result_step), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal, "test result error")); +} + +TEST_F(DirectComprehensionTest, Shortcircuit) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(0) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + ASSERT_OK(compre_step->Evaluate(frame, result, trail)); + EXPECT_THAT(result, BoolValueIs(false)); +} + +TEST_F(DirectComprehensionTest, IterationLimit) { + cel::RuntimeOptions options; + options.comprehension_max_iterations = 2; + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(1) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(true)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/true, -1); + + Value result; + AttributeTrail trail; + EXPECT_THAT(compre_step->Evaluate(frame, result, trail), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(DirectComprehensionTest, Exhaustive) { + cel::RuntimeOptions options; + + ExecutionFrameBase frame(empty_activation_, /*callback=*/nullptr, options, + value_manager_.get(), slots_); + + auto loop_step = std::make_unique(); + MockDirectStep* mock = loop_step.get(); + + EXPECT_CALL(*mock, Evaluate(_, _, _)) + .Times(2) + .WillRepeatedly([](ExecutionFrameBase&, Value& result, AttributeTrail&) { + result = BoolValue(false); + return absl::OkStatus(); + }); + + ASSERT_OK_AND_ASSIGN(auto list, MakeList()); + + auto compre_step = CreateDirectComprehensionStep( + 0, 1, + /*range_step=*/CreateConstValueDirectStep(std::move(list)), + /*accu_init=*/CreateConstValueDirectStep(BoolValue(false)), + /*loop_step=*/std::move(loop_step), + /*condition_step=*/CreateConstValueDirectStep(BoolValue(false)), + /*result_step=*/CreateDirectSlotIdentStep("__result__", 1, -1), + /*shortcircuiting=*/false, -1); + + Value result; + AttributeTrail trail; + ASSERT_OK(compre_step->Evaluate(frame, result, trail)); + EXPECT_THAT(result, BoolValueIs(false)); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 7f1d2d006..53ed03faa 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -9,15 +9,25 @@ #include "common/value.h" #include "common/value_manager.h" #include "eval/eval/compiler_constant_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "internal/status_macros.h" #include "runtime/internal/convert_constant.h" namespace google::api::expr::runtime { +namespace { + using ::cel::ast_internal::Constant; using ::cel::runtime_internal::ConvertConstant; +} // namespace + +std::unique_ptr CreateConstValueDirectStep( + cel::Value value, int64_t id) { + return std::make_unique(std::move(value), id); +} + absl::StatusOr> CreateConstValueStep( cel::Value value, int64_t expr_id, bool comes_from_ast) { return std::make_unique(std::move(value), expr_id, diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index 396adcd6f..f3a95a6cb 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -8,10 +8,14 @@ #include "base/ast_internal/expr.h" #include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +std::unique_ptr CreateConstValueDirectStep( + cel::Value value, int64_t expr_id = -1); + // Factory method for Constant Value expression step. absl::StatusOr> CreateConstValueStep( cel::Value value, int64_t expr_id, bool comes_from_ast = true); diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 19893c62c..e88f6be9e 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -10,18 +10,23 @@ #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/ast_internal/expr.h" #include "base/attribute.h" #include "base/kind.h" #include "common/casting.h" #include "common/native_type.h" #include "common/value.h" #include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/attribute_utility.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" #include "internal/casts.h" #include "internal/number.h" #include "internal/status_macros.h" +#include "runtime/internal/errors.h" namespace google::api::expr::runtime { @@ -29,8 +34,10 @@ namespace { using ::cel::AttributeQualifier; using ::cel::BoolValue; +using ::cel::Cast; using ::cel::DoubleValue; - +using ::cel::ErrorValue; +using ::cel::InstanceOf; using ::cel::IntValue; using ::cel::ListValue; using ::cel::MapValue; @@ -45,36 +52,6 @@ using ::cel::runtime_internal::CreateNoSuchKeyError; inline constexpr int kNumContainerAccessArguments = 2; -// ContainerAccessStep performs message field access specified by Expr::Select -// message. -class ContainerAccessStep : public ExpressionStepBase { - public: - ContainerAccessStep(int64_t expr_id, bool enable_optional_types) - : ExpressionStepBase(expr_id), - enable_optional_types_(enable_optional_types) {} - - absl::Status Evaluate(ExecutionFrame* frame) const override; - - private: - struct LookupResult { - ValueView value; - AttributeTrail trail; - }; - - LookupResult PerformLookup(ExecutionFrame* frame, Value& scratch) const; - absl::StatusOr LookupInMap(const MapValue& cel_map, - const Value& key, ExecutionFrame* frame, - Value& scratch) const; - absl::StatusOr LookupInList(const ListValue& cel_list, - const Value& key, - ExecutionFrame* frame, - Value& scratch) const; - absl::StatusOr Lookup(const Value& container, const Value& key, - ExecutionFrame* frame, Value& scratch) const; - - const bool enable_optional_types_; -}; - absl::optional CelNumberFromValue(const Value& value) { switch (value->kind()) { case ValueKind::kInt64: @@ -118,165 +95,164 @@ AttributeQualifier AttributeQualifierFromValue(const Value& v) { } } -absl::StatusOr ContainerAccessStep::LookupInMap( - const MapValue& cel_map, const Value& key, ExecutionFrame* frame, - Value& scratch) const { - if (frame->enable_heterogeneous_numeric_lookups()) { +ValueView LookupInMap(const MapValue& cel_map, const Value& key, + ExecutionFrameBase& frame, Value& scratch) { + if (frame.options().enable_heterogeneous_equality) { // Double isn't a supported key type but may be convertible to an integer. absl::optional number = CelNumberFromValue(key); if (number.has_value()) { // Consider uint as uint first then try coercion (prefer matching the // original type of the key value). if (key->Is()) { - ValueView value; - bool ok; - CEL_ASSIGN_OR_RETURN( - std::tie(value, ok), - cel_map.Find(frame->value_factory(), key, scratch)); - if (ok) { - return value; + auto lookup = cel_map.Find(frame.value_manager(), key, scratch); + if (!lookup.ok()) { + scratch = frame.value_manager().CreateErrorValue( + std::move(lookup).status()); + return ValueView{scratch}; + } + if (lookup->second) { + return lookup->first; } } // double / int / uint -> int if (number->LosslessConvertibleToInt()) { - ValueView value; - bool ok; - CEL_ASSIGN_OR_RETURN( - std::tie(value, ok), - cel_map.Find(frame->value_factory(), - frame->value_factory().CreateIntValue(number->AsInt()), - scratch)); - if (ok) { - return value; + auto lookup = cel_map.Find( + frame.value_manager(), + frame.value_manager().CreateIntValue(number->AsInt()), scratch); + if (!lookup.ok()) { + scratch = frame.value_manager().CreateErrorValue( + std::move(lookup).status()); + return ValueView{scratch}; + } + if (lookup->second) { + return lookup->first; } } // double / int -> uint if (number->LosslessConvertibleToUint()) { - ValueView value; - bool ok; - CEL_ASSIGN_OR_RETURN( - std::tie(value, ok), - cel_map.Find( - frame->value_factory(), - frame->value_factory().CreateUintValue(number->AsUint()), - scratch)); - if (ok) { - return value; + auto lookup = cel_map.Find( + frame.value_manager(), + frame.value_manager().CreateUintValue(number->AsUint()), scratch); + if (!lookup.ok()) { + scratch = frame.value_manager().CreateErrorValue( + std::move(lookup).status()); + return ValueView{scratch}; + } + if (lookup->second) { + return lookup->first; } } - scratch = frame->value_factory().CreateErrorValue( + scratch = frame.value_manager().CreateErrorValue( CreateNoSuchKeyError(key->DebugString())); - return scratch; + return ValueView{scratch}; } } - CEL_RETURN_IF_ERROR(CheckMapKeyType(key)); + absl::Status status = CheckMapKeyType(key); + if (!status.ok()) { + scratch = frame.value_manager().CreateErrorValue(std::move(status)); + return ValueView{scratch}; + } - return cel_map.Get(frame->value_factory(), key, scratch); + absl::StatusOr lookup = + cel_map.Get(frame.value_manager(), key, scratch); + if (!lookup.ok()) { + scratch = + frame.value_manager().CreateErrorValue(std::move(lookup).status()); + return ValueView{scratch}; + } + return *lookup; } -absl::StatusOr ContainerAccessStep::LookupInList( - const ListValue& cel_list, const Value& key, ExecutionFrame* frame, - Value& scratch) const { +ValueView LookupInList(const ListValue& cel_list, const Value& key, + ExecutionFrameBase& frame, Value& scratch) { absl::optional maybe_idx; - if (frame->enable_heterogeneous_numeric_lookups()) { + if (frame.options().enable_heterogeneous_equality) { auto number = CelNumberFromValue(key); if (number.has_value() && number->LosslessConvertibleToInt()) { maybe_idx = number->AsInt(); } - } else if (key->Is()) { + } else if (InstanceOf(key)) { maybe_idx = key.As().NativeValue(); } - if (maybe_idx.has_value()) { - int64_t idx = *maybe_idx; - if (idx < 0 || idx >= cel_list.Size()) { - return absl::UnknownError( - absl::StrCat("Index error: index=", idx, " size=", cel_list.Size())); - } - return cel_list.Get(frame->value_factory(), idx, scratch); + if (!maybe_idx.has_value()) { + scratch = frame.value_manager().CreateErrorValue(absl::UnknownError( + absl::StrCat("Index error: expected integer type, got ", + cel::KindToString(ValueKindToKind(key->kind()))))); + return ValueView{scratch}; } - return absl::UnknownError( - absl::StrCat("Index error: expected integer type, got ", - cel::KindToString(ValueKindToKind(key->kind())))); + int64_t idx = *maybe_idx; + if (idx < 0 || idx >= cel_list.Size()) { + scratch = frame.value_manager().CreateErrorValue(absl::UnknownError( + absl::StrCat("Index error: index=", idx, " size=", cel_list.Size()))); + return ValueView{scratch}; + } + + absl::StatusOr lookup = + cel_list.Get(frame.value_manager(), idx, scratch); + + if (!lookup.ok()) { + scratch = + frame.value_manager().CreateErrorValue(std::move(lookup).status()); + return ValueView{scratch}; + } + return *lookup; } -absl::StatusOr ContainerAccessStep::Lookup(const Value& container, - const Value& key, - ExecutionFrame* frame, - Value& scratch) const { +ValueView LookupInContainer(const Value& container, const Value& key, + ExecutionFrameBase& frame, Value& scratch) { // Select steps can be applied to either maps or messages - switch (container->kind()) { + switch (container.kind()) { case ValueKind::kMap: { - auto result = LookupInMap(container.As(), key, frame, scratch); - if (!result.ok()) { - scratch = - frame->value_factory().CreateErrorValue(std::move(result).status()); - return ValueView{scratch}; - } - return *result; + return LookupInMap(Cast(container), key, frame, scratch); } case ValueKind::kList: { - auto result = - LookupInList(container.As(), key, frame, scratch); - if (!result.ok()) { - scratch = - frame->value_factory().CreateErrorValue(std::move(result).status()); - return ValueView{scratch}; - } - return *result; + return LookupInList(Cast(container), key, frame, scratch); } default: scratch = - frame->value_factory().CreateErrorValue(absl::InvalidArgumentError( + frame.value_manager().CreateErrorValue(absl::InvalidArgumentError( absl::StrCat("Invalid container type: '", ValueKindToString(container->kind()), "'"))); return ValueView{scratch}; } } -ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( - ExecutionFrame* frame, Value& scratch) const { - auto input_args = frame->value_stack().GetSpan(kNumContainerAccessArguments); - AttributeTrail trail; - - const Value& container = input_args[0]; - const Value& key = input_args[1]; - - if (frame->enable_unknowns()) { - auto unknown_set = frame->attribute_utility().MergeUnknowns(input_args); - - if (unknown_set) { - scratch = std::move(unknown_set).value(); - return {ValueView{scratch}, std::move(trail)}; +ValueView PerformLookup(ExecutionFrameBase& frame, const Value& container, + const Value& key, const AttributeTrail& container_trail, + bool enable_optional_types, Value& scratch, + AttributeTrail& trail) { + if (frame.unknown_processing_enabled()) { + AttributeUtility::Accumulator unknowns = + frame.attribute_utility().CreateAccumulator(); + unknowns.MaybeAdd(container); + unknowns.MaybeAdd(key); + + if (!unknowns.IsEmpty()) { + scratch = std::move(unknowns).Build(); + return ValueView{scratch}; } - // We guarantee that GetAttributeSpan can acquire this number of arguments - // by calling HasEnough() at the beginning of Execute() method. - absl::Span input_attrs = - frame->value_stack().GetAttributeSpan(kNumContainerAccessArguments); - const auto& container_trail = input_attrs[0]; trail = container_trail.Step(AttributeQualifierFromValue(key)); - if (frame->attribute_utility().CheckForUnknown(trail, - /*use_partial=*/false)) { - cel::Attribute attr = trail.attribute(); - scratch = frame->attribute_utility().CreateUnknownSet(attr); - return {ValueView{scratch}, std::move(trail)}; + if (frame.attribute_utility().CheckForUnknownExact(trail)) { + scratch = frame.attribute_utility().CreateUnknownSet(trail.attribute()); + return ValueView{scratch}; } } - if (container.Is()) { + if (InstanceOf(container)) { scratch = container; - return {ValueView{scratch}, std::move(trail)}; + return ValueView{scratch}; } - if (key.Is()) { + if (InstanceOf(key)) { scratch = key; - return {ValueView{scratch}, std::move(trail)}; + return ValueView{scratch}; } - if (enable_optional_types_ && + if (enable_optional_types && cel::NativeTypeId::Of(container) == cel::NativeTypeId::For()) { const auto& optional_value = @@ -284,32 +260,37 @@ ContainerAccessStep::LookupResult ContainerAccessStep::PerformLookup( cel::Cast(container).operator->()); if (!optional_value.HasValue()) { scratch = cel::OptionalValue::None(); - return {ValueView{scratch}, std::move(trail)}; - } - auto result = Lookup(optional_value.Value(), key, frame, scratch); - if (!result.ok()) { - scratch = - frame->value_factory().CreateErrorValue(std::move(result).status()); - return {ValueView{scratch}, std::move(trail)}; + return ValueView{scratch}; } - if (auto error_value = cel::As(*result); + auto result = + LookupInContainer(optional_value.Value(), key, frame, scratch); + if (auto error_value = cel::As(result); error_value && cel::IsNoSuchKey(error_value->NativeValue())) { scratch = cel::OptionalValue::None(); - return {ValueView{scratch}, std::move(trail)}; + return ValueView{scratch}; } - scratch = cel::OptionalValue::Of(frame->memory_manager(), Value{*result}); - return {ValueView{scratch}, std::move(trail)}; + scratch = cel::OptionalValue::Of(frame.value_manager().GetMemoryManager(), + Value{result}); + return ValueView{scratch}; } - auto result = Lookup(container, key, frame, scratch); - if (!result.ok()) { - scratch = - frame->value_factory().CreateErrorValue(std::move(result).status()); - return {ValueView{scratch}, std::move(trail)}; - } - return {*result, std::move(trail)}; + return LookupInContainer(container, key, frame, scratch); } +// ContainerAccessStep performs message field access specified by Expr::Select +// message. +class ContainerAccessStep : public ExpressionStepBase { + public: + explicit ContainerAccessStep(bool enable_optional_types, int64_t expr_id) + : ExpressionStepBase(expr_id), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + bool enable_optional_types_; +}; + absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(kNumContainerAccessArguments)) { return absl::Status( @@ -318,14 +299,68 @@ absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { } Value scratch; - auto result = PerformLookup(frame, scratch); - frame->value_stack().PopAndPush(kNumContainerAccessArguments, - Value{result.value}, std::move(result.trail)); + AttributeTrail result_trail; + auto args = frame->value_stack().GetSpan(kNumContainerAccessArguments); + const AttributeTrail& container_trail = + frame->value_stack().GetAttributeSpan(kNumContainerAccessArguments)[0]; + + auto result = PerformLookup(*frame, args[0], args[1], container_trail, + enable_optional_types_, scratch, result_trail); + frame->value_stack().PopAndPush(kNumContainerAccessArguments, Value{result}, + std::move(result_trail)); return absl::OkStatus(); } + +class DirectContainerAccessStep : public DirectExpressionStep { + public: + DirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, + bool enable_optional_types, int64_t expr_id) + : DirectExpressionStep(expr_id), + container_step_(std::move(container_step)), + key_step_(std::move(key_step)), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override; + + private: + std::unique_ptr container_step_; + std::unique_ptr key_step_; + bool enable_optional_types_; +}; + +absl::Status DirectContainerAccessStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value container; + Value key; + AttributeTrail container_trail; + AttributeTrail key_trail; + + CEL_RETURN_IF_ERROR( + container_step_->Evaluate(frame, container, container_trail)); + CEL_RETURN_IF_ERROR(key_step_->Evaluate(frame, key, key_trail)); + + result = PerformLookup(frame, container, key, container_trail, + enable_optional_types_, result, trail); + + return absl::OkStatus(); +} + } // namespace +std::unique_ptr CreateDirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, bool enable_optional_types, + int64_t expr_id) { + return std::make_unique( + std::move(container_step), std::move(key_step), enable_optional_types, + expr_id); +} + // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const cel::ast_internal::Call& call, int64_t expr_id, diff --git a/eval/eval/container_access_step.h b/eval/eval/container_access_step.h index 2c6f5a600..05bd76f0c 100644 --- a/eval/eval/container_access_step.h +++ b/eval/eval/container_access_step.h @@ -6,10 +6,16 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +std::unique_ptr CreateDirectContainerAccessStep( + std::unique_ptr container_step, + std::unique_ptr key_step, bool enable_optional_types, + int64_t expr_id); + // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const cel::ast_internal::Call& call, int64_t expr_id, diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 880b0faec..a2e5c11d5 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -8,11 +8,11 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/struct.pb.h" -#include "google/protobuf/arena.h" #include "absl/status/status.h" #include "base/builtins.h" #include "base/type_provider.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" @@ -25,8 +25,10 @@ #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" +#include "eval/public/unknown_set.h" #include "internal/testing.h" #include "parser/parser.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -42,12 +44,12 @@ using testing::AllOf; using testing::HasSubstr; using cel::internal::StatusIs; -using TestParamType = std::tuple; +using TestParamType = std::tuple; -// Helper method. Looks up in registry and tests comparison operation. CelValue EvaluateAttributeHelper( - google::protobuf::Arena* arena, CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, const std::vector& patterns) { + google::protobuf::Arena* arena, CelValue container, CelValue key, + bool use_recursive_impl, bool receiver_style, bool enable_unknown, + const std::vector& patterns) { ExecutionPath path; Expr expr; @@ -64,10 +66,19 @@ CelValue EvaluateAttributeHelper( container_expr.mutable_ident_expr().set_name("container"); key_expr.mutable_ident_expr().set_name("key"); - path.push_back( - std::move(CreateIdentStep(container_expr.ident_expr(), 1).value())); - path.push_back(std::move(CreateIdentStep(key_expr.ident_expr(), 2).value())); - path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); + if (use_recursive_impl) { + path.push_back(std::make_unique( + CreateDirectContainerAccessStep(CreateDirectIdentStep("container", 1), + CreateDirectIdentStep("key", 2), + /*enable_optional_types=*/false, 3), + 3)); + } else { + path.push_back( + std::move(CreateIdentStep(container_expr.ident_expr(), 1).value())); + path.push_back( + std::move(CreateIdentStep(key_expr.ident_expr(), 2).value())); + path.push_back(std::move(CreateContainerAccessStep(call, 3).value())); + } cel::RuntimeOptions options; options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; @@ -87,16 +98,17 @@ CelValue EvaluateAttributeHelper( class ContainerAccessStepTest : public ::testing::Test { protected: - ContainerAccessStepTest() {} + ContainerAccessStepTest() = default; void SetUp() override {} CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, + bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, patterns); + enable_unknown, use_recursive_impl, + patterns); } google::protobuf::Arena arena_; }; @@ -104,7 +116,7 @@ class ContainerAccessStepTest : public ::testing::Test { class ContainerAccessStepUniformityTest : public ::testing::TestWithParam { protected: - ContainerAccessStepUniformityTest() {} + ContainerAccessStepUniformityTest() = default; void SetUp() override {} @@ -118,13 +130,19 @@ class ContainerAccessStepUniformityTest return std::get<1>(params); } + bool use_recursive_impl() { + TestParamType params = GetParam(); + return std::get<2>(params); + } + // Helper method. Looks up in registry and tests comparison operation. CelValue EvaluateAttribute( CelValue container, CelValue key, bool receiver_style, - bool enable_unknown, + bool enable_unknown, bool use_recursive_impl = false, const std::vector& patterns = {}) { return EvaluateAttributeHelper(&arena_, container, key, receiver_style, - enable_unknown, patterns); + enable_unknown, use_recursive_impl, + patterns); } google::protobuf::Arena arena_; }; @@ -278,8 +296,9 @@ TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { "container", {CreateCelAttributeQualifierPattern(CelValue::CreateInt64(1))})}; - result = EvaluateAttribute(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(1), true, true, patterns); + result = + EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, false, patterns); ASSERT_TRUE(result.IsUnknownSet()); } @@ -351,10 +370,11 @@ TEST_F(ContainerAccessStepTest, TestInvalidContainerType) { HasSubstr("Invalid container type: 'int"))); } -INSTANTIATE_TEST_SUITE_P(CombinedContainerTest, - ContainerAccessStepUniformityTest, - testing::Combine(/*receiver_style*/ testing::Bool(), - /*unknown_enabled*/ testing::Bool())); +INSTANTIATE_TEST_SUITE_P( + CombinedContainerTest, ContainerAccessStepUniformityTest, + testing::Combine(/*receiver_style*/ testing::Bool(), + /*unknown_enabled*/ testing::Bool(), + /*use_recursive_impl*/ testing::Bool())); class ContainerAccessHeterogeneousLookupsTest : public testing::Test { public: diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index a5a5ac150..c842d7969 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -4,13 +4,19 @@ #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" +#include "base/ast_internal/expr.h" #include "common/casting.h" #include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/attribute_utility.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" #include "runtime/internal/mutable_list_impl.h" @@ -19,24 +25,26 @@ namespace google::api::expr::runtime { namespace { +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; using ::cel::ListValueBuilderInterface; using ::cel::UnknownValue; +using ::cel::Value; using ::cel::runtime_internal::MutableListValue; class CreateListStep : public ExpressionStepBase { public: - CreateListStep(int64_t expr_id, int list_size, bool immutable, + CreateListStep(int64_t expr_id, int list_size, absl::flat_hash_set optional_indices) : ExpressionStepBase(expr_id), list_size_(list_size), - immutable_(immutable), optional_indices_(std::move(optional_indices)) {} absl::Status Evaluate(ExecutionFrame* frame) const override; private: int list_size_; - bool immutable_; absl::flat_hash_set optional_indices_; }; @@ -97,14 +105,7 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { } } - if (immutable_) { - result = std::move(*builder).Build(); - } else { - result = cel::OpaqueValue{ - frame->value_manager().GetMemoryManager().MakeShared( - std::move(builder))}; - } - frame->value_stack().PopAndPush(list_size_, std::move(result)); + frame->value_stack().PopAndPush(list_size_, std::move(*builder).Build()); return absl::OkStatus(); } @@ -117,20 +118,148 @@ absl::flat_hash_set MakeOptionalIndicesSet( return optional_indices; } +class CreateListDirectStep : public DirectExpressionStep { + public: + CreateListDirectStep( + std::vector> elements, + absl::flat_hash_set optional_indices, int64_t expr_id) + : DirectExpressionStep(expr_id), + elements_(std::move(elements)), + optional_indices_(std::move(optional_indices)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override { + CEL_ASSIGN_OR_RETURN(auto builder, + frame.value_manager().NewListValueBuilder( + frame.value_manager().GetDynListType())); + + builder->Reserve(elements_.size()); + AttributeUtility::Accumulator unknowns = + frame.attribute_utility().CreateAccumulator(); + AttributeTrail tmp_attr; + + for (size_t i = 0; i < elements_.size(); ++i) { + const auto& element = elements_[i]; + CEL_RETURN_IF_ERROR(element->Evaluate(frame, result, tmp_attr)); + + if (cel::InstanceOf(result)) return absl::OkStatus(); + + if (frame.attribute_tracking_enabled()) { + if (frame.missing_attribute_errors_enabled()) { + if (frame.attribute_utility().CheckForMissingAttribute(tmp_attr)) { + CEL_ASSIGN_OR_RETURN( + result, frame.attribute_utility().CreateMissingAttributeError( + tmp_attr.attribute())); + return absl::OkStatus(); + } + } + if (frame.unknown_processing_enabled()) { + if (InstanceOf(result)) { + unknowns.Add(Cast(result)); + } + if (frame.attribute_utility().CheckForUnknown(tmp_attr, + /*use_partial=*/true)) { + unknowns.Add(tmp_attr); + } + } + } + + // Conditionally add if optional. + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = + cel::As(static_cast(result)); + optional_arg) { + if (!optional_arg->HasValue()) { + continue; + } + CEL_RETURN_IF_ERROR(builder->Add(optional_arg->Value())); + continue; + } + return cel::TypeConversionError(result.GetTypeName(), "optional_type") + .NativeValue(); + } + + // Otherwise just add. + CEL_RETURN_IF_ERROR(builder->Add(std::move(result))); + } + + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); + } + result = std::move(*builder).Build(); + + return absl::OkStatus(); + } + + private: + std::vector> elements_; + absl::flat_hash_set optional_indices_; +}; + +class MutableListStep : public ExpressionStepBase { + public: + explicit MutableListStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; +}; + +absl::Status MutableListStep::Evaluate(ExecutionFrame* frame) const { + CEL_ASSIGN_OR_RETURN(auto builder, + frame->value_manager().NewListValueBuilder( + frame->value_manager().GetDynListType())); + + frame->value_stack().Push(cel::OpaqueValue{ + frame->value_manager().GetMemoryManager().MakeShared( + std::move(builder))}); + return absl::OkStatus(); +} + +class DirectMutableListStep : public DirectExpressionStep { + public: + explicit DirectMutableListStep(int64_t expr_id) + : DirectExpressionStep(expr_id) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; +}; + +absl::Status DirectMutableListStep::Evaluate( + ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + CEL_ASSIGN_OR_RETURN(auto builder, + frame.value_manager().NewListValueBuilder( + frame.value_manager().GetDynListType())); + + result = cel::OpaqueValue{ + frame.value_manager().GetMemoryManager().MakeShared( + std::move(builder))}; + return absl::OkStatus(); +} + } // namespace +std::unique_ptr CreateDirectListStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + std::move(deps), std::move(optional_indices), expr_id); +} + absl::StatusOr> CreateCreateListStep( const cel::ast_internal::CreateList& create_list_expr, int64_t expr_id) { return std::make_unique( - expr_id, create_list_expr.elements().size(), /*immutable=*/true, + expr_id, create_list_expr.elements().size(), MakeOptionalIndicesSet(create_list_expr)); } -absl::StatusOr> CreateCreateMutableListStep( - const cel::ast_internal::CreateList& create_list_expr, int64_t expr_id) { - return std::make_unique( - expr_id, create_list_expr.elements().size(), /*immutable=*/false, - MakeOptionalIndicesSet(create_list_expr)); +std::unique_ptr CreateMutableListStep(int64_t expr_id) { + return std::make_unique(expr_id); +} + +std::unique_ptr CreateDirectMutableListStep( + int64_t expr_id) { + return std::make_unique(expr_id); } } // namespace google::api::expr::runtime diff --git a/eval/eval/create_list_step.h b/eval/eval/create_list_step.h index 36aa32c9d..77e8d0bb3 100644 --- a/eval/eval/create_list_step.h +++ b/eval/eval/create_list_step.h @@ -3,22 +3,37 @@ #include #include +#include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Factory method for CreateList that evaluates recursively. +std::unique_ptr CreateDirectListStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); + // Factory method for CreateList which constructs an immutable list. absl::StatusOr> CreateCreateListStep( const cel::ast_internal::CreateList& create_list_expr, int64_t expr_id); -// Factory method for CreateList which constructs a mutable list as the list -// construction step is generated by a macro AST rewrite rather than by a user -// entered expression. -absl::StatusOr> CreateCreateMutableListStep( - const cel::ast_internal::CreateList& create_list_expr, int64_t expr_id); +// Factory method for CreateList which constructs a mutable list. +// +// This is intended for the list construction step is generated for a +// list-building comprehension (rather than a user authored expression). +std::unique_ptr CreateMutableListStep(int64_t expr_id); + +// Factory method for CreateList which constructs a mutable list. +// +// This is intended for the list construction step is generated for a +// list-building comprehension (rather than a user authored expression). +std::unique_ptr CreateDirectMutableListStep( + int64_t expr_id); } // namespace google::api::expr::runtime diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index c49a22777..b5d583355 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -1,13 +1,25 @@ #include "eval/eval/create_list_step.h" +#include #include #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "base/ast_internal/expr.h" +#include "base/attribute.h" +#include "base/attribute_set.h" +#include "base/type_provider.h" +#include "common/casting.h" +#include "common/memory.h" +#include "common/value.h" +#include "common/value_testing.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/internal/interop.h" @@ -17,17 +29,34 @@ #include "eval/public/unknown_attribute_set.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { namespace { +using ::cel::Attribute; +using ::cel::AttributeQualifier; +using ::cel::AttributeSet; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ListValue; using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; using ::cel::ast_internal::Expr; +using ::cel::test::IntValueIs; using testing::Eq; +using testing::HasSubstr; using testing::Not; +using testing::UnorderedElementsAre; using cel::internal::IsOk; +using cel::internal::IsOkAndHolds; +using cel::internal::StatusIs; // Helper method. Creates simple pipeline containing Select step and runs it. absl::StatusOr RunExpression(const std::vector& values, @@ -195,6 +224,9 @@ TEST_P(CreateListStepTest, CreateListHundred) { } } +INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, + testing::Bool()); + TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { google::protobuf::Arena arena; std::vector values; @@ -220,8 +252,252 @@ TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { EXPECT_THAT(result_set->unknown_attributes().size(), Eq(2)); } -INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, - testing::Bool()); +TEST(CreateDirectListStep, Basic) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep(IntValue(2), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).Size(), 2); +} + +TEST(CreateDirectListStep, ForwardFirstError) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateErrorValue(absl::InternalError("test1")), -1)); + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateErrorValue(absl::InternalError("test2")), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test1")); +} + +std::vector UnknownAttrNames(const UnknownValue& v) { + std::vector names; + names.reserve(v.attribute_set().size()); + + for (const auto& attr : v.attribute_set()) { + EXPECT_OK(attr.AsString().status()); + names.push_back(attr.AsString().value_or("")); + } + return names; +} + +TEST(CreateDirectListStep, MergeUnknowns) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + AttributeSet attr_set1({Attribute("var1")}); + AttributeSet attr_set2({Attribute("var2")}); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateUnknownValue(std::move(attr_set1)), -1)); + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateUnknownValue(std::move(attr_set2)), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(UnknownAttrNames(Cast(result)), + UnorderedElementsAre("var1", "var2")); +} + +TEST(CreateDirectListStep, ErrorBeforeUnknown) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + AttributeSet attr_set1({Attribute("var1")}); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateErrorValue(absl::InternalError("test1")), -1)); + deps.push_back(CreateConstValueDirectStep( + value_factory.get().CreateErrorValue(absl::InternalError("test2")), -1)); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test1")); +} + +class SetAttrDirectStep : public DirectExpressionStep { + public: + explicit SetAttrDirectStep(Attribute attr) + : DirectExpressionStep(-1), attr_(std::move(attr)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attr) const override { + result = frame.value_manager().GetNullValue(); + attr = AttributeTrail(attr_); + return absl::OkStatus(); + } + + private: + cel::Attribute attr_; +}; + +TEST(CreateDirectListStep, MissingAttribute) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.enable_missing_attribute_errors = true; + + activation.SetMissingPatterns({cel::AttributePattern( + "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back( + CreateConstValueDirectStep(value_factory.get().GetNullValue(), -1)); + deps.push_back(std::make_unique( + Attribute("var1", {AttributeQualifier::OfString("field1")}))); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT( + Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1.field1"))); +} + +TEST(CreateDirectListStep, OptionalPresentSet) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep( + cel::OptionalValue::Of(value_factory.get().GetMemoryManager(), + IntValue(2)), + -1)); + auto step = CreateDirectListStep(std::move(deps), {1}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + auto list = Cast(result); + EXPECT_THAT(list.Size(), Eq(2)); + EXPECT_THAT(list.Get(value_factory.get(), 0), IsOkAndHolds(IntValueIs(1))); + EXPECT_THAT(list.Get(value_factory.get(), 1), IsOkAndHolds(IntValueIs(2))); +} + +TEST(CreateDirectListStep, OptionalAbsentNotSet) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back(CreateConstValueDirectStep(IntValue(1), -1)); + deps.push_back(CreateConstValueDirectStep(cel::OptionalValue::None(), -1)); + auto step = CreateDirectListStep(std::move(deps), {1}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + auto list = Cast(result); + EXPECT_THAT(list.Size(), Eq(1)); + EXPECT_THAT(list.Get(value_factory.get(), 0), IsOkAndHolds(IntValueIs(1))); +} + +TEST(CreateDirectListStep, PartialUnknown) { + cel::ManagedValueFactory value_factory( + cel::TypeProvider::Builtin(), cel::MemoryManagerRef::ReferenceCounting()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + activation.SetUnknownPatterns({cel::AttributePattern( + "var1", {cel::AttributeQualifierPattern::OfString("field1")})}); + + ExecutionFrameBase frame(activation, options, value_factory.get()); + + std::vector> deps; + deps.push_back( + CreateConstValueDirectStep(value_factory.get().CreateIntValue(1), -1)); + deps.push_back(std::make_unique(Attribute("var1", {}))); + auto step = CreateDirectListStep(std::move(deps), {}, -1); + + cel::Value result; + AttributeTrail attr; + + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(UnknownAttrNames(Cast(result)), + UnorderedElementsAre("var1")); +} } // namespace diff --git a/eval/eval/create_map_step.cc b/eval/eval/create_map_step.cc index a96b9ded2..7b70646f3 100644 --- a/eval/eval/create_map_step.cc +++ b/eval/eval/create_map_step.cc @@ -25,12 +25,13 @@ #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "base/ast_internal/expr.h" #include "common/casting.h" #include "common/memory.h" #include "common/type.h" #include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" @@ -39,21 +40,13 @@ namespace google::api::expr::runtime { namespace { +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; using ::cel::StructValueBuilderInterface; using ::cel::UnknownValue; using ::cel::Value; -absl::flat_hash_set MakeOptionalIndicesSet( - const cel::ast_internal::CreateStruct& create_struct_expr) { - absl::flat_hash_set optional_indices; - for (size_t i = 0; i < create_struct_expr.entries().size(); ++i) { - if (create_struct_expr.entries()[i].optional_entry()) { - optional_indices.insert(static_cast(i)); - } - } - return optional_indices; -} - // `CreateStruct` implementation for map. class CreateStructStepForMap final : public ExpressionStepBase { public: @@ -133,16 +126,129 @@ absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class DirectCreateMapStep : public DirectExpressionStep { + public: + DirectCreateMapStep(std::vector> deps, + absl::flat_hash_set optional_indices, + int64_t expr_id) + : DirectExpressionStep(expr_id), + deps_(std::move(deps)), + optional_indices_(std::move(optional_indices)), + entry_count_(deps_.size() / 2) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::vector> deps_; + absl::flat_hash_set optional_indices_; + size_t entry_count_; +}; + +absl::Status DirectCreateMapStep::Evaluate( + ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + Value key; + Value value; + AttributeTrail tmp_attr; + auto unknowns = frame.attribute_utility().CreateAccumulator(); + + CEL_ASSIGN_OR_RETURN(auto builder, + frame.value_manager().NewMapValueBuilder( + frame.value_manager().GetDynDynMapType())); + builder->Reserve(entry_count_); + + for (size_t i = 0; i < entry_count_; i += 1) { + int map_key_index = 2 * i; + int map_value_index = map_key_index + 1; + CEL_RETURN_IF_ERROR(deps_[map_key_index]->Evaluate(frame, key, tmp_attr)); + + if (InstanceOf(key)) { + result = key; + return absl::OkStatus(); + } + + if (frame.unknown_processing_enabled()) { + if (InstanceOf(key)) { + unknowns.Add(Cast(key)); + } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { + unknowns.Add(tmp_attr); + } + } + + CEL_RETURN_IF_ERROR( + deps_[map_value_index]->Evaluate(frame, value, tmp_attr)); + + if (InstanceOf(value)) { + result = value; + return absl::OkStatus(); + } + + if (frame.unknown_processing_enabled()) { + if (InstanceOf(value)) { + unknowns.Add(Cast(value)); + } else if (frame.attribute_utility().CheckForUnknownPartial(tmp_attr)) { + unknowns.Add(tmp_attr); + } + } + + // Preserve the stack machine behavior of forwarding unknowns before + // errors. + if (!unknowns.IsEmpty()) { + continue; + } + + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_map_value = + cel::As(static_cast(value)); + optional_map_value) { + if (!optional_map_value->HasValue()) { + continue; + } + auto key_status = + builder->Put(std::move(key), optional_map_value->Value()); + if (!key_status.ok()) { + result = frame.value_manager().CreateErrorValue(key_status); + return absl::OkStatus(); + } + continue; + } + return cel::TypeConversionError(value.DebugString(), "optional_type") + .NativeValue(); + } + + CEL_RETURN_IF_ERROR(cel::CheckMapKey(key)); + auto put_status = builder->Put(std::move(key), std::move(value)); + if (!put_status.ok()) { + result = frame.value_manager().CreateErrorValue(put_status); + return absl::OkStatus(); + } + } + + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); + } + + result = std::move(*builder).Build(); + return absl::OkStatus(); +} + } // namespace +std::unique_ptr CreateDirectCreateMapStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + std::move(deps), std::move(optional_indices), expr_id); +} absl::StatusOr> CreateCreateStructStepForMap( - const cel::ast_internal::CreateStruct& create_struct_expr, + size_t entry_count, absl::flat_hash_set optional_indices, int64_t expr_id) { // Make map-creating step. - return std::make_unique( - expr_id, create_struct_expr.entries().size(), - MakeOptionalIndicesSet(create_struct_expr)); + return std::make_unique(expr_id, entry_count, + std::move(optional_indices)); } } // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step.h b/eval/eval/create_map_step.h index 56ea3bfc0..f9be4be0c 100644 --- a/eval/eval/create_map_step.h +++ b/eval/eval/create_map_step.h @@ -15,18 +15,30 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_MAP_STEP_H_ +#include #include #include +#include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" -#include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Creates an expression step that evaluates a create map expression. +// +// Deps must have an even number of elements, that alternate key, value pairs. +// (key1, value1, key2, value2...). +std::unique_ptr CreateDirectCreateMapStep( + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); + // Creates an `ExpressionStep` which performs `CreateStruct` for a map. absl::StatusOr> CreateCreateStructStepForMap( - const cel::ast_internal::CreateStruct& create_struct_expr, int64_t expr_id); + size_t entry_count, absl::flat_hash_set optional_indices, + int64_t expr_id); } // namespace google::api::expr::runtime diff --git a/eval/eval/create_map_step_test.cc b/eval/eval/create_map_step_test.cc index e2eea9064..e18f76a21 100644 --- a/eval/eval/create_map_step_test.cc +++ b/eval/eval/create_map_step_test.cc @@ -14,6 +14,7 @@ #include "eval/eval/create_map_step.h" +#include #include #include #include @@ -24,6 +25,7 @@ #include "base/ast_internal/expr.h" #include "base/type_provider.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" @@ -43,17 +45,13 @@ using ::cel::TypeProvider; using ::cel::ast_internal::Expr; using ::google::protobuf::Arena; - -// Helper method. Creates simple pipeline containing CreateStruct step that -// builds Map and runs it. -absl::StatusOr RunCreateMapExpression( +absl::StatusOr CreateStackMachineProgram( const std::vector>& values, - google::protobuf::Arena* arena, bool enable_unknowns) { + Activation& activation) { ExecutionPath path; - Activation activation; - Expr expr0; Expr expr1; + Expr expr0; std::vector exprs; exprs.reserve(values.size() * 2); @@ -86,10 +84,52 @@ absl::StatusOr RunCreateMapExpression( index++; } - CEL_ASSIGN_OR_RETURN(auto step1, - CreateCreateStructStepForMap(create_struct, expr1.id())); + CEL_ASSIGN_OR_RETURN( + auto step1, CreateCreateStructStepForMap(values.size(), {}, expr1.id())); path.push_back(std::move(step1)); + return path; +} + +absl::StatusOr CreateRecursiveProgram( + const std::vector>& values, + Activation& activation) { + ExecutionPath path; + + int index = 0; + std::vector> deps; + for (const auto& item : values) { + std::string key_name = absl::StrCat("key", index); + std::string value_name = absl::StrCat("value", index); + + deps.push_back(CreateDirectIdentStep(key_name, -1)); + + deps.push_back(CreateDirectIdentStep(value_name, -1)); + + activation.InsertValue(key_name, item.first); + activation.InsertValue(value_name, item.second); + + index++; + } + path.push_back(std::make_unique( + CreateDirectCreateMapStep(std::move(deps), {}, -1), -1)); + + return path; +} +// Helper method. Creates simple pipeline containing CreateStruct step that +// builds Map and runs it. +// Equivalent to {key0: value0, ...} +absl::StatusOr RunCreateMapExpression( + const std::vector>& values, + google::protobuf::Arena* arena, bool enable_unknowns, bool enable_recursive_program) { + Activation activation; + + ExecutionPath path; + if (enable_recursive_program) { + CEL_ASSIGN_OR_RETURN(path, CreateRecursiveProgram(values, activation)); + } else { + CEL_ASSIGN_OR_RETURN(path, CreateStackMachineProgram(values, activation)); + } cel::RuntimeOptions options; if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; @@ -101,13 +141,24 @@ absl::StatusOr RunCreateMapExpression( return cel_expr.Evaluate(activation, arena); } -class CreateMapStepTest : public testing::TestWithParam {}; +class CreateMapStepTest + : public testing::TestWithParam> { + public: + bool enable_unknowns() { return std::get<0>(GetParam()); } + bool enable_recursive_program() { return std::get<1>(GetParam()); } + + absl::StatusOr RunMapExpression( + const std::vector>& values, + google::protobuf::Arena* arena) { + return RunCreateMapExpression(values, arena, enable_unknowns(), + enable_recursive_program()); + } +}; // Test that Empty Map is created successfully. TEST_P(CreateMapStepTest, TestCreateEmptyMap) { Arena arena; - ASSERT_OK_AND_ASSIGN(CelValue result, - RunCreateMapExpression({}, &arena, GetParam())); + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression({}, &arena)); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); @@ -128,7 +179,24 @@ TEST(CreateMapStepTest, TestMapCreateWithUnknown) { CelValue::CreateUnknownSet(&unknown_set)}); ASSERT_OK_AND_ASSIGN(CelValue result, - RunCreateMapExpression(entries, &arena, true)); + RunCreateMapExpression(entries, &arena, true, false)); + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST(CreateMapStepTest, TestMapCreateWithUnknownRecursiveProgram) { + Arena arena; + UnknownSet unknown_set; + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back({CelValue::CreateString(&kKeys[1]), + CelValue::CreateUnknownSet(&unknown_set)}); + + ASSERT_OK_AND_ASSIGN(CelValue result, + RunCreateMapExpression(entries, &arena, true, true)); ASSERT_TRUE(result.IsUnknownSet()); } @@ -145,8 +213,7 @@ TEST_P(CreateMapStepTest, TestCreateStringMap) { entries.push_back( {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); - ASSERT_OK_AND_ASSIGN(CelValue result, - RunCreateMapExpression(entries, &arena, GetParam())); + ASSERT_OK_AND_ASSIGN(CelValue result, RunMapExpression(entries, &arena)); ASSERT_TRUE(result.IsMap()); const CelMap* cel_map = result.MapOrDie(); @@ -163,7 +230,8 @@ TEST_P(CreateMapStepTest, TestCreateStringMap) { EXPECT_EQ(lookup1->Int64OrDie(), 1); } -INSTANTIATE_TEST_SUITE_P(CreateMapStep, CreateMapStepTest, testing::Bool()); +INSTANTIATE_TEST_SUITE_P(CreateMapStep, CreateMapStepTest, + testing::Combine(testing::Bool(), testing::Bool())); } // namespace diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 85007fee1..9de7dca21 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -14,7 +14,6 @@ #include "eval/eval/create_struct_step.h" -#include #include #include #include @@ -31,6 +30,8 @@ #include "common/memory.h" #include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" @@ -39,21 +40,13 @@ namespace google::api::expr::runtime { namespace { +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; using ::cel::StructValueBuilderInterface; using ::cel::UnknownValue; using ::cel::Value; -absl::flat_hash_set MakeOptionalIndicesSet( - const cel::ast_internal::CreateStruct& create_struct_expr) { - absl::flat_hash_set optional_indices; - for (size_t i = 0; i < create_struct_expr.entries().size(); ++i) { - if (create_struct_expr.entries()[i].optional_entry()) { - optional_indices.insert(static_cast(i)); - } - } - return optional_indices; -} - // `CreateStruct` implementation for message/struct. class CreateStructStepForStruct final : public ExpressionStepBase { public: @@ -136,26 +129,122 @@ absl::Status CreateStructStepForStruct::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } -} // namespace -absl::StatusOr> CreateCreateStructStepForStruct( - const cel::ast_internal::CreateStruct& create_struct_expr, std::string name, - int64_t expr_id, cel::TypeManager& type_manager) { - // We resolved to a struct type. Use it. - std::vector entries; - entries.reserve(create_struct_expr.entries().size()); - for (const auto& entry : create_struct_expr.entries()) { - CEL_ASSIGN_OR_RETURN(auto field, type_manager.FindStructTypeFieldByName( - name, entry.field_key())); - if (!field.has_value()) { - return absl::InvalidArgumentError(absl::StrCat( - "Invalid message creation: field '", entry.field_key(), - "' not found in '", create_struct_expr.message_name(), "'")); +class DirectCreateStructStep : public DirectExpressionStep { + public: + DirectCreateStructStep( + int64_t expr_id, std::string name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices) + : DirectExpressionStep(expr_id), + name_(std::move(name)), + field_keys_(std::move(field_keys)), + deps_(std::move(deps)), + optional_indices_(std::move(optional_indices)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& trail) const override; + + private: + std::string name_; + std::vector field_keys_; + std::vector> deps_; + absl::flat_hash_set optional_indices_; +}; + +absl::Status DirectCreateStructStep::Evaluate(ExecutionFrameBase& frame, + Value& result, + AttributeTrail& trail) const { + Value field_value; + AttributeTrail field_attr; + auto unknowns = frame.attribute_utility().CreateAccumulator(); + + auto builder_or_status = frame.value_manager().NewValueBuilder(name_); + if (!builder_or_status.ok()) { + result = frame.value_manager().CreateErrorValue(builder_or_status.status()); + return absl::OkStatus(); + } + if (!builder_or_status->has_value()) { + result = frame.value_manager().CreateErrorValue( + absl::NotFoundError(absl::StrCat("Unable to find builder: ", name_))); + return absl::OkStatus(); + } + auto& builder = **builder_or_status; + + for (int i = 0; i < field_keys_.size(); i++) { + CEL_RETURN_IF_ERROR(deps_[i]->Evaluate(frame, field_value, field_attr)); + + // TODO(uncreated-issue/67): if the value is an error, we should be able to return + // early, however some client tests depend on the error message the struct + // impl returns in the stack machine version. + if (InstanceOf(field_value)) { + result = std::move(field_value); + return absl::OkStatus(); + } + + if (frame.unknown_processing_enabled()) { + if (InstanceOf(field_value)) { + unknowns.Add(Cast(field_value)); + } else if (frame.attribute_utility().CheckForUnknownPartial(field_attr)) { + unknowns.Add(field_attr); + } + } + + if (!unknowns.IsEmpty()) { + continue; + } + + if (optional_indices_.contains(static_cast(i))) { + if (auto optional_arg = cel::As( + static_cast(field_value)); + optional_arg) { + if (!optional_arg->HasValue()) { + continue; + } + auto status = + builder->SetFieldByName(field_keys_[i], optional_arg->Value()); + if (!status.ok()) { + result = frame.value_manager().CreateErrorValue(status); + return absl::OkStatus(); + } + } + continue; } - entries.push_back(entry.field_key()); + + auto status = + builder->SetFieldByName(field_keys_[i], std::move(field_value)); + if (!status.ok()) { + result = frame.value_manager().CreateErrorValue(status); + return absl::OkStatus(); + } + } + + if (!unknowns.IsEmpty()) { + result = std::move(unknowns).Build(); + return absl::OkStatus(); } + + result = std::move(*builder).Build(); + return absl::OkStatus(); +} + +} // namespace + +std::unique_ptr CreateDirectCreateStructStep( + std::string resolved_name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id) { + return std::make_unique( + expr_id, std::move(resolved_name), std::move(field_keys), std::move(deps), + std::move(optional_indices)); +} + +std::unique_ptr CreateCreateStructStep( + std::string name, std::vector field_keys, + absl::flat_hash_set optional_indices, int64_t expr_id) { + // MakeOptionalIndicesSet(create_struct_expr) return std::make_unique( - expr_id, std::move(name), std::move(entries), - MakeOptionalIndicesSet(create_struct_expr)); + expr_id, std::move(name), std::move(field_keys), + std::move(optional_indices)); } } // namespace google::api::expr::runtime diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index 73c53be62..eb80634f8 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -18,19 +18,26 @@ #include #include #include +#include -#include "absl/status/statusor.h" -#include "base/ast_internal/expr.h" -#include "common/type_manager.h" +#include "absl/container/flat_hash_set.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { // Creates an `ExpressionStep` which performs `CreateStruct` for a // message/struct. -absl::StatusOr> CreateCreateStructStepForStruct( - const cel::ast_internal::CreateStruct& create_struct_expr, std::string name, - int64_t expr_id, cel::TypeManager& type_manager); +std::unique_ptr CreateDirectCreateStructStep( + std::string name, std::vector field_keys, + std::vector> deps, + absl::flat_hash_set optional_indices, int64_t expr_id); + +// Creates an `ExpressionStep` which performs `CreateStruct` for a +// message/struct. +std::unique_ptr CreateCreateStructStep( + std::string name, std::vector field_keys, + absl::flat_hash_set optional_indices, int64_t expr_id); } // namespace google::api::expr::runtime diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index 7c891b126..6b37c8312 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -29,6 +29,7 @@ #include "base/type_provider.h" #include "common/values/legacy_value_manager.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" @@ -63,14 +64,52 @@ using testing::Eq; using testing::IsNull; using testing::Not; using testing::Pointwise; -using cel::internal::StatusIs; + +absl::StatusOr MakeStackMachinePath(absl::string_view field) { + ExecutionPath path; + Expr expr0; + + auto& ident = expr0.mutable_ident_expr(); + ident.set_name("message"); + CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep(ident, expr0.id())); + + auto step1 = CreateCreateStructStep("google.api.expr.runtime.TestMessage", + {std::string(field)}, + /*optional_indices=*/{}, + + /*id=*/-1); + + path.push_back(std::move(step0)); + path.push_back(std::move(step1)); + + return path; +} + +absl::StatusOr MakeRecursivePath(absl::string_view field) { + ExecutionPath path; + + std::vector> deps; + deps.push_back(CreateDirectIdentStep("message", -1)); + + auto step1 = + CreateDirectCreateStructStep("google.api.expr.runtime.TestMessage", + {std::string(field)}, std::move(deps), + /*optional_indices=*/{}, + + /*id=*/-1); + + path.push_back(std::make_unique(std::move(step1), -1)); + + return path; +} + // Helper method. Creates simple pipeline containing CreateStruct step that // builds message and runs it. absl::StatusOr RunExpression(absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, - bool enable_unknowns) { - ExecutionPath path; + bool enable_unknowns, + bool enable_recursive_planning) { CelTypeRegistry type_registry; type_registry.RegisterTypeProvider( std::make_unique( @@ -80,37 +119,26 @@ absl::StatusOr RunExpression(absl::string_view field, cel::common_internal::LegacyValueManager type_manager( memory_manager, type_registry.GetTypeProvider()); - Expr expr0; - Expr expr1; - - auto& ident = expr0.mutable_ident_expr(); - ident.set_name("message"); - CEL_ASSIGN_OR_RETURN(auto step0, CreateIdentStep(ident, expr0.id())); - - auto& create_struct = expr1.mutable_struct_expr(); - create_struct.set_message_name("google.api.expr.runtime.TestMessage"); - - auto& entry = create_struct.mutable_entries().emplace_back(); - entry.set_field_key(std::string(field)); - - CEL_ASSIGN_OR_RETURN(auto maybe_type, - type_manager.FindType(create_struct.message_name())); + CEL_ASSIGN_OR_RETURN( + auto maybe_type, + type_manager.FindType("google.api.expr.runtime.TestMessage")); if (!maybe_type.has_value()) { return absl::Status(absl::StatusCode::kFailedPrecondition, "missing proto message type"); } - CEL_ASSIGN_OR_RETURN(auto step1, - CreateCreateStructStepForStruct( - create_struct, "google.api.expr.runtime.TestMessage", - expr1.id(), type_manager)); - - path.push_back(std::move(step0)); - path.push_back(std::move(step1)); cel::RuntimeOptions options; if (enable_unknowns) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } + ExecutionPath path; + + if (enable_recursive_planning) { + CEL_ASSIGN_OR_RETURN(path, MakeRecursivePath(field)); + } else { + CEL_ASSIGN_OR_RETURN(path, MakeStackMachinePath(field)); + } + CelExpressionFlatImpl cel_expr( FlatExpression(std::move(path), /*comprehension_slot_count=*/0, type_registry.GetTypeProvider(), options)); @@ -122,9 +150,11 @@ absl::StatusOr RunExpression(absl::string_view field, void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns) { + bool enable_unknowns, + bool enable_recursive_planning) { ASSERT_OK_AND_ASSIGN(auto result, - RunExpression(field, value, arena, enable_unknowns)); + RunExpression(field, value, arena, enable_unknowns, + enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); @@ -137,13 +167,15 @@ void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, void RunExpressionAndGetMessage(absl::string_view field, std::vector values, google::protobuf::Arena* arena, TestMessage* test_msg, - bool enable_unknowns) { + bool enable_unknowns, + bool enable_recursive_planning) { ContainerBackedListImpl cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); ASSERT_OK_AND_ASSIGN(auto result, - RunExpression(field, value, arena, enable_unknowns)); + RunExpression(field, value, arena, enable_unknowns, + enable_recursive_planning)); ASSERT_TRUE(result.IsMessage()) << result.DebugString(); const Message* msg = result.MessageOrDie(); @@ -153,7 +185,12 @@ void RunExpressionAndGetMessage(absl::string_view field, test_msg->MergeFrom(*msg); } -class CreateCreateStructStepTest : public testing::TestWithParam {}; +class CreateCreateStructStepTest + : public testing::TestWithParam> { + public: + bool enable_unknowns() { return std::get<0>(GetParam()); } + bool enable_recursive_planning() { return std::get<1>(GetParam()); } +}; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; @@ -166,24 +203,34 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { auto memory_manager = ProtoMemoryManagerRef(&arena); cel::common_internal::LegacyValueManager type_manager( memory_manager, type_registry.GetTypeProvider()); - Expr expr1; - auto& create_struct = expr1.mutable_struct_expr(); - create_struct.set_message_name("google.api.expr.runtime.TestMessage"); - auto adapter = type_registry.FindTypeAdapter(create_struct.message_name()); + auto adapter = + type_registry.FindTypeAdapter("google.api.expr.runtime.TestMessage"); ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); - ASSERT_OK_AND_ASSIGN(auto maybe_type, - type_manager.FindType(create_struct.message_name())); + ASSERT_OK_AND_ASSIGN( + auto maybe_type, + type_manager.FindType("google.api.expr.runtime.TestMessage")); ASSERT_TRUE(maybe_type.has_value()); - ASSERT_OK_AND_ASSIGN(auto step, - CreateCreateStructStepForStruct( - create_struct, "google.api.expr.runtime.TestMessage", - expr1.id(), type_manager)); - path.push_back(std::move(step)); + if (enable_recursive_planning()) { + auto step = + CreateDirectCreateStructStep("google.api.expr.runtime.TestMessage", + /*fields=*/{}, + /*deps=*/{}, + /*optional_indices=*/{}, + /*id=*/-1); + path.push_back( + std::make_unique(std::move(step), /*id=*/-1)); + } else { + auto step = CreateCreateStructStep("google.api.expr.runtime.TestMessage", + /*fields=*/{}, + /*optional_indices=*/{}, + /*id=*/-1); + path.push_back(std::move(step)); + } cel::RuntimeOptions options; - if (GetParam()) { + if (enable_unknowns(), enable_recursive_planning()) { options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; } CelExpressionFlatImpl cel_expr( @@ -199,49 +246,30 @@ TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); } -TEST_P(CreateCreateStructStepTest, TestMessageCreationBadField) { - ExecutionPath path; - CelTypeRegistry type_registry; - type_registry.RegisterTypeProvider( - std::make_unique( - google::protobuf::DescriptorPool::generated_pool(), - google::protobuf::MessageFactory::generated_factory())); - google::protobuf::Arena arena; - auto memory_manager = ProtoMemoryManagerRef(&arena); - cel::common_internal::LegacyValueManager type_manager( - memory_manager, type_registry.GetTypeProvider()); - Expr expr1; - - auto& create_struct = expr1.mutable_struct_expr(); - create_struct.set_message_name("google.api.expr.runtime.TestMessage"); - auto& entry = create_struct.mutable_entries().emplace_back(); - entry.set_field_key("bad_field"); - auto& value = entry.mutable_value(); - value.mutable_const_expr().set_bool_value(true); - auto adapter = type_registry.FindTypeAdapter(create_struct.message_name()); - ASSERT_TRUE(adapter.has_value() && adapter->mutation_apis() != nullptr); +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { + Arena arena; + TestMessage test_msg; + UnknownSet unknown_set; - ASSERT_OK_AND_ASSIGN(auto maybe_type, - type_manager.FindType(create_struct.message_name())); - ASSERT_TRUE(maybe_type.has_value()); - EXPECT_THAT(CreateCreateStructStepForStruct( - create_struct, "google.api.expr.runtime.TestMessage", - expr1.id(), type_manager) - .status(), - StatusIs(absl::StatusCode::kInvalidArgument, - testing::HasSubstr("'bad_field'"))); + auto eval_status = + RunExpression("bool_value", CelValue::CreateUnknownSet(&unknown_set), + &arena, true, /*enable_recursive_planning=*/false); + ASSERT_OK(eval_status); + ASSERT_TRUE(eval_status->IsUnknownSet()); } // Test message creation if unknown argument is passed -TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { +TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknownRecursive) { Arena arena; TestMessage test_msg; UnknownSet unknown_set; - auto eval_status = RunExpression( - "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true); + auto eval_status = + RunExpression("bool_value", CelValue::CreateUnknownSet(&unknown_set), + &arena, true, /*enable_recursive_planning=*/true); ASSERT_OK(eval_status); - ASSERT_TRUE(eval_status->IsUnknownSet()); + ASSERT_TRUE(eval_status->IsUnknownSet()) << eval_status->DebugString(); } // Test that fields of type bool are set correctly @@ -250,7 +278,8 @@ TEST_P(CreateCreateStructStepTest, TestSetBoolField) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_value", CelValue::CreateBool(true), &arena, &test_msg, GetParam())); + "bool_value", CelValue::CreateBool(true), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.bool_value(), true); } @@ -260,7 +289,8 @@ TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); + "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int32_value(), 1); } @@ -270,9 +300,9 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint32_value", CelValue::CreateUint64(1), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint32_value", CelValue::CreateUint64(1), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint32_value(), 1); } @@ -283,7 +313,8 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); + "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.int64_value(), 1); } @@ -293,9 +324,9 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint64_value", CelValue::CreateUint64(1), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint64_value", CelValue::CreateUint64(1), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.uint64_value(), 1); } @@ -305,9 +336,9 @@ TEST_P(CreateCreateStructStepTest, TestSetFloatField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("float_value", CelValue::CreateDouble(2.0), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "float_value", CelValue::CreateDouble(2.0), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.float_value(), 2.0); } @@ -317,9 +348,9 @@ TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("double_value", CelValue::CreateDouble(2.0), - &arena, &test_msg, GetParam())); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "double_value", CelValue::CreateDouble(2.0), &arena, &test_msg, + enable_unknowns(), enable_recursive_planning())); EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0); } @@ -332,7 +363,7 @@ TEST_P(CreateCreateStructStepTest, TestSetStringField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg, - GetParam())); + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.string_value(), kTestStr); } @@ -346,7 +377,7 @@ TEST_P(CreateCreateStructStepTest, TestSetBytesField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg, - GetParam())); + enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.bytes_value(), kTestStr); } @@ -361,7 +392,7 @@ TEST_P(CreateCreateStructStepTest, TestSetDurationField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "duration_value", CelProtoWrapper::CreateDuration(&test_duration), &arena, - &test_msg, GetParam())); + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } @@ -376,7 +407,7 @@ TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "timestamp_value", CelProtoWrapper::CreateTimestamp(&test_timestamp), - &arena, &test_msg, GetParam())); + &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } @@ -393,7 +424,7 @@ TEST_P(CreateCreateStructStepTest, TestSetMessageField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "message_value", CelProtoWrapper::CreateMessage(&orig_msg, &arena), - &arena, &test_msg, GetParam())); + &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } @@ -413,7 +444,7 @@ TEST_P(CreateCreateStructStepTest, TestSetAnyField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "any_value", CelProtoWrapper::CreateMessage(&orig_embedded_msg, &arena), - &arena, &test_msg, GetParam())); + &arena, &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_THAT(test_msg, EqualsProto(orig_msg)); TestMessage test_embedded_msg; @@ -428,7 +459,7 @@ TEST_P(CreateCreateStructStepTest, TestSetEnumField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena, - &test_msg, GetParam())); + &test_msg, enable_unknowns(), enable_recursive_planning())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } @@ -444,7 +475,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_list", values, &arena, &test_msg, GetParam())); + "bool_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } @@ -460,7 +492,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_list", values, &arena, &test_msg, GetParam())); + "int32_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } @@ -476,7 +509,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_list", values, &arena, &test_msg, GetParam())); + "uint32_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } @@ -492,7 +526,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_list", values, &arena, &test_msg, GetParam())); + "int64_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } @@ -508,7 +543,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_list", values, &arena, &test_msg, GetParam())); + "uint64_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } @@ -524,7 +560,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_list", values, &arena, &test_msg, GetParam())); + "float_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } @@ -540,7 +577,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_list", values, &arena, &test_msg, GetParam())); + "double_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } @@ -556,7 +594,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_list", values, &arena, &test_msg, GetParam())); + "string_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } @@ -572,7 +611,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_list", values, &arena, &test_msg, GetParam())); + "bytes_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } @@ -591,7 +631,8 @@ TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { } ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "message_list", values, &arena, &test_msg, GetParam())); + "message_list", values, &arena, &test_msg, enable_unknowns(), + enable_recursive_planning())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); } @@ -617,7 +658,7 @@ TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); @@ -644,7 +685,7 @@ TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); @@ -671,7 +712,7 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, - GetParam())); + enable_unknowns(), enable_recursive_planning())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); @@ -679,7 +720,7 @@ TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { } INSTANTIATE_TEST_SUITE_P(CombinedCreateStructTest, CreateCreateStructStepTest, - testing::Bool()); + testing::Combine(testing::Bool(), testing::Bool())); } // namespace diff --git a/eval/eval/direct_expression_step.cc b/eval/eval/direct_expression_step.cc new file mode 100644 index 000000000..2d7fc6fc0 --- /dev/null +++ b/eval/eval/direct_expression_step.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "eval/eval/direct_expression_step.h" + +#include + +#include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" +#include "internal/status_macros.h" + +namespace google::api::expr::runtime { + +absl::Status WrappedDirectStep::Evaluate(ExecutionFrame* frame) const { + cel::Value result; + AttributeTrail attribute_trail; + CEL_RETURN_IF_ERROR(impl_->Evaluate(*frame, result, attribute_trail)); + frame->value_stack().Push(std::move(result), std::move(attribute_trail)); + return absl::OkStatus(); +} + +} // namespace google::api::expr::runtime diff --git a/eval/eval/direct_expression_step.h b/eval/eval/direct_expression_step.h new file mode 100644 index 000000000..f11479065 --- /dev/null +++ b/eval/eval/direct_expression_step.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/evaluator_core.h" + +namespace google::api::expr::runtime { + +// Represents a directly evaluated CEL expression. +// +// Subexpressions assign to values on the C++ program stack and call their +// dependencies directly. +// +// This reduces the setup overhead for evaluation and minimizes value churn +// to / from a heap based value stack managed by the CEL runtime, but can't be +// used for arbitrarily nested expressions. +class DirectExpressionStep { + public: + explicit DirectExpressionStep(int64_t expr_id) : expr_id_(expr_id) {} + DirectExpressionStep() : expr_id_(-1) {} + + virtual ~DirectExpressionStep() = default; + + int64_t expr_id() const { return expr_id_; } + bool comes_from_ast() const { return expr_id_ >= 0; } + + virtual absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const = 0; + + // Return a type id for this node. + // + // Users must not make any assumptions about the type if the default value is + // returned. + virtual cel::NativeTypeId GetNativeTypeId() const { + return cel::NativeTypeId(); + } + + // Implementations optionally support inspecting the program tree. + virtual absl::optional> + GetDependencies() const { + return absl::nullopt; + } + + // Implementations optionally support extracting the program tree. + // + // Extract prevents the callee from functioning, and is only intended for use + // when replacing a given expression step. + virtual absl::optional>> + ExtractDependencies() { + return absl::nullopt; + }; + + protected: + int64_t expr_id_; +}; + +// Wrapper for direct steps to work with the stack machine impl. +class WrappedDirectStep : public ExpressionStep { + public: + WrappedDirectStep(std::unique_ptr impl, int64_t expr_id) + : ExpressionStep(expr_id, false), impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + const DirectExpressionStep* wrapped() const { return impl_.get(); } + + private: + std::unique_ptr impl_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_DIRECT_EXPRESSION_STEP_H_ diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index c09308365..2902976d6 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -24,15 +24,14 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/utility/utility.h" #include "base/type_provider.h" #include "common/memory.h" #include "common/value.h" #include "common/value_manager.h" -#include "internal/status_macros.h" #include "runtime/activation_interface.h" +#include "runtime/managed_value_factory.h" namespace google::api::expr::runtime { @@ -121,7 +120,7 @@ class EvaluationStatus final { } // namespace absl::StatusOr ExecutionFrame::Evaluate( - EvaluationListener listener) { + EvaluationListener& listener) { const size_t initial_stack_size = value_stack().size(); if (!listener) { @@ -185,9 +184,15 @@ absl::StatusOr FlatExpression::EvaluateWithCallback( FlatExpressionEvaluatorState& state) const { state.Reset(); - ExecutionFrame frame(subexpressions_, activation, options_, state); + ExecutionFrame frame(subexpressions_, activation, options_, state, + std::move(listener)); - return frame.Evaluate(std::move(listener)); + return frame.Evaluate(frame.callback()); +} + +cel::ManagedValueFactory FlatExpression::MakeValueFactory( + cel::MemoryManagerRef memory_manager) const { + return cel::ManagedValueFactory(type_provider_, memory_manager); } } // namespace google::api::expr::runtime diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index ed900563d..6ef8bc77f 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -21,6 +21,7 @@ #include #include +#include "absl/base/nullability.h" #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -135,10 +136,102 @@ class FlatExpressionEvaluatorState { cel::ValueManager* value_factory_; }; +// Context needed for evaluation. This is sufficient for supporting stackless +// recursive evaluation, but larger expressions require a full execution frame. +class ExecutionFrameBase { + public: + // Overload for test usages. + ExecutionFrameBase(const cel::ActivationInterface& activation, + const cel::RuntimeOptions& options, + cel::ValueManager& value_manager) + : activation_(&activation), + callback_(), + options_(&options), + value_manager_(&value_manager), + attribute_utility_(activation.GetUnknownAttributes(), + activation.GetMissingAttributes(), value_manager), + slots_(&ComprehensionSlots::GetEmptyInstance()), + max_iterations_(options.comprehension_max_iterations), + iterations_(0) {} + + ExecutionFrameBase(const cel::ActivationInterface& activation, + EvaluationListener callback, + const cel::RuntimeOptions& options, + cel::ValueManager& value_manager, + ComprehensionSlots& slots) + : activation_(&activation), + callback_(std::move(callback)), + options_(&options), + value_manager_(&value_manager), + attribute_utility_(activation.GetUnknownAttributes(), + activation.GetMissingAttributes(), value_manager), + slots_(&slots), + max_iterations_(options.comprehension_max_iterations), + iterations_(0) {} + + const cel::ActivationInterface& activation() const { return *activation_; } + + EvaluationListener& callback() { return callback_; } + + const cel::RuntimeOptions& options() const { return *options_; } + + cel::ValueManager& value_manager() { return *value_manager_; } + + const AttributeUtility& attribute_utility() const { + return attribute_utility_; + } + + bool attribute_tracking_enabled() const { + return options_->unknown_processing != + cel::UnknownProcessingOptions::kDisabled || + options_->enable_missing_attribute_errors; + } + + bool missing_attribute_errors_enabled() const { + return options_->enable_missing_attribute_errors; + } + + bool unknown_processing_enabled() const { + return options_->unknown_processing != + cel::UnknownProcessingOptions::kDisabled; + } + + bool unknown_function_results_enabled() const { + return options_->unknown_processing == + cel::UnknownProcessingOptions::kAttributeAndFunction; + } + + ComprehensionSlots& comprehension_slots() { return *slots_; } + + // Increment iterations and return an error if the iteration budget is + // exceeded + absl::Status IncrementIterations() { + if (max_iterations_ == 0) { + return absl::OkStatus(); + } + iterations_++; + if (iterations_ >= max_iterations_) { + return absl::Status(absl::StatusCode::kInternal, + "Iteration budget exceeded"); + } + return absl::OkStatus(); + } + + protected: + absl::Nonnull activation_; + EvaluationListener callback_; + absl::Nonnull options_; + absl::Nonnull value_manager_; + AttributeUtility attribute_utility_; + absl::Nonnull slots_; + const int max_iterations_; + int iterations_; +}; + // ExecutionFrame manages the context needed for expression evaluation. // The lifecycle of the object is bound to a FlateExpression::Evaluate*(...) // call. -class ExecutionFrame { +class ExecutionFrame : public ExecutionFrameBase { public: // flat is the flattened sequence of execution steps that will be evaluated. // activation provides bindings between parameter names and values. @@ -147,33 +240,25 @@ class ExecutionFrame { ExecutionFrame(ExecutionPathView flat, const cel::ActivationInterface& activation, const cel::RuntimeOptions& options, - FlatExpressionEvaluatorState& state) - : pc_(0UL), + FlatExpressionEvaluatorState& state, + EvaluationListener callback = EvaluationListener()) + : ExecutionFrameBase(activation, std::move(callback), options, + state.value_manager(), state.comprehension_slots()), + pc_(0UL), execution_path_(flat), - activation_(activation), - options_(options), state_(state), - attribute_utility_(activation_.GetUnknownAttributes(), - activation_.GetMissingAttributes(), - state_.value_factory()), - max_iterations_(options_.comprehension_max_iterations), - iterations_(0), subexpressions_() {} ExecutionFrame(absl::Span subexpressions, const cel::ActivationInterface& activation, const cel::RuntimeOptions& options, - FlatExpressionEvaluatorState& state) - : pc_(0UL), + FlatExpressionEvaluatorState& state, + EvaluationListener callback = EvaluationListener()) + : ExecutionFrameBase(activation, std::move(callback), options, + state.value_manager(), state.comprehension_slots()), + pc_(0UL), execution_path_(subexpressions[0]), - activation_(activation), - options_(options), state_(state), - attribute_utility_(activation_.GetUnknownAttributes(), - activation_.GetMissingAttributes(), - state_.value_factory()), - max_iterations_(options_.comprehension_max_iterations), - iterations_(0), subexpressions_(subexpressions) { ABSL_DCHECK(!subexpressions.empty()); } @@ -182,7 +267,9 @@ class ExecutionFrame { const ExpressionStep* Next(); // Evaluate the execution frame to completion. - absl::StatusOr Evaluate(EvaluationListener listener); + absl::StatusOr Evaluate(EvaluationListener& listener); + // Evaluate the execution frame to completion. + absl::StatusOr Evaluate() { return Evaluate(callback()); } // Intended for use in builtin shortcutting operations. // @@ -222,36 +309,27 @@ class ExecutionFrame { } EvaluatorStack& value_stack() { return state_.value_stack(); } - ComprehensionSlots& comprehension_slots() { - return state_.comprehension_slots(); - } bool enable_attribute_tracking() const { - return options_.unknown_processing != - cel::UnknownProcessingOptions::kDisabled || - options_.enable_missing_attribute_errors; + return attribute_tracking_enabled(); } - bool enable_unknowns() const { - return options_.unknown_processing != - cel::UnknownProcessingOptions::kDisabled; - } + bool enable_unknowns() const { return unknown_processing_enabled(); } bool enable_unknown_function_results() const { - return options_.unknown_processing == - cel::UnknownProcessingOptions::kAttributeAndFunction; + return unknown_function_results_enabled(); } bool enable_missing_attribute_errors() const { - return options_.enable_missing_attribute_errors; + return missing_attribute_errors_enabled(); } bool enable_heterogeneous_numeric_lookups() const { - return options_.enable_heterogeneous_equality; + return options().enable_heterogeneous_equality; } bool enable_comprehension_list_append() const { - return options_.enable_comprehension_list_append; + return options().enable_comprehension_list_append; } cel::MemoryManagerRef memory_manager() { return state_.memory_manager(); } @@ -262,29 +340,9 @@ class ExecutionFrame { cel::ValueManager& value_factory() { return state_.value_factory(); } - cel::ValueManager& value_manager() { return state_.value_factory(); } - - const AttributeUtility& attribute_utility() const { - return attribute_utility_; - } - // Returns reference to the modern API activation. const cel::ActivationInterface& modern_activation() const { - return activation_; - } - - // Increment iterations and return an error if the iteration budget is - // exceeded - absl::Status IncrementIterations() { - if (max_iterations_ == 0) { - return absl::OkStatus(); - } - iterations_++; - if (iterations_ >= max_iterations_) { - return absl::Status(absl::StatusCode::kInternal, - "Iteration budget exceeded"); - } - return absl::OkStatus(); + return *activation_; } private: @@ -296,12 +354,7 @@ class ExecutionFrame { size_t pc_; // pc_ - Program Counter. Current position on execution path. ExecutionPathView execution_path_; - const cel::ActivationInterface& activation_; - const cel::RuntimeOptions& options_; // owned by the FlatExpr instance FlatExpressionEvaluatorState& state_; - AttributeUtility attribute_utility_; - const int max_iterations_; - int iterations_; absl::Span subexpressions_; std::vector call_stack_; }; @@ -355,8 +408,19 @@ class FlatExpression { const cel::ActivationInterface& activation, EvaluationListener listener, FlatExpressionEvaluatorState& state) const; + cel::ManagedValueFactory MakeValueFactory( + cel::MemoryManagerRef memory_manager) const; + const ExecutionPath& path() const { return path_; } + absl::Span subexpressions() const { + return subexpressions_; + } + + const cel::RuntimeOptions& options() const { return options_; } + + size_t comprehension_slots_size() const { return comprehension_slots_size_; } + private: ExecutionPath path_; std::vector subexpressions_; diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index ed81f9b3f..f75323099 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -8,16 +8,21 @@ #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "base/ast_internal/expr.h" #include "base/function.h" #include "base/function_descriptor.h" #include "base/kind.h" +#include "common/casting.h" #include "common/value.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" @@ -69,6 +74,32 @@ bool ArgumentKindsMatch(const cel::FunctionDescriptor& descriptor, return true; } +// Adjust new type names to legacy equivalent. int -> int64_t. +// Temporary fix to migrate value types without breaking clients. +// TODO(uncreated-issue/46): Update client tests that depend on this value. +std::string ToLegacyKindName(absl::string_view type_name) { + if (type_name == "int" || type_name == "uint") { + return absl::StrCat(type_name, "64"); + } + + return std::string(type_name); +} + +std::string CallArgTypeString(absl::Span args) { + std::string call_sig_string = ""; + + for (size_t i = 0; i < args.size(); i++) { + const auto& arg = args[i]; + if (!call_sig_string.empty()) { + absl::StrAppend(&call_sig_string, ", "); + } + absl::StrAppend( + &call_sig_string, + ToLegacyKindName(cel::KindToString(ValueKindToKind(arg->kind())))); + } + return absl::StrCat("(", call_sig_string, ")"); +} + // Convert partially unknown arguments to unknowns before passing to the // function. // TODO(issues/52): See if this can be refactored to remove the eager @@ -109,17 +140,6 @@ bool IsUnknownFunctionResultError(const Value& result) { return payload.has_value() && payload.value() == "true"; } -// Adjust new type names to legacy equivalent. int -> int64_t. -// Temporary fix to migrate value types without breaking clients. -// TODO(uncreated-issue/46): Update client tests that depend on this value. -std::string ToLegacyKindName(absl::string_view type_name) { - if (type_name == "int" || type_name == "uint") { - return absl::StrCat(type_name, "64"); - } - - return std::string(type_name); -} - // Simple wrapper around a function resolution result. A function call should // resolve to a single function implementation and a descriptor or none. using ResolveResult = absl::optional; @@ -155,6 +175,52 @@ class AbstractFunctionStep : public ExpressionStepBase { size_t num_arguments_; }; +inline absl::StatusOr Invoke( + const cel::FunctionOverloadReference& overload, int64_t expr_id, + absl::Span args, ExecutionFrameBase& frame) { + FunctionEvaluationContext context(frame.value_manager()); + + CEL_ASSIGN_OR_RETURN(Value result, + overload.implementation.Invoke(context, args)); + + if (frame.unknown_function_results_enabled() && + IsUnknownFunctionResultError(result)) { + return frame.attribute_utility().CreateUnknownSet(overload.descriptor, + expr_id, args); + } + return result; +} + +Value NoOverloadResult(absl::string_view name, + absl::Span args, + ExecutionFrameBase& frame) { + // No matching overloads. + // Such absence can be caused by presence of CelError in arguments. + // To enable behavior of functions that accept CelError( &&, || ), CelErrors + // should be propagated along execution path. + for (size_t i = 0; i < args.size(); i++) { + const auto& arg = args[i]; + if (cel::InstanceOf(arg)) { + return arg; + } + } + + if (frame.unknown_processing_enabled()) { + // Already converted partial unknowns to unknown sets so just merge. + absl::optional unknown_set = + frame.attribute_utility().MergeUnknowns(args); + if (unknown_set.has_value()) { + return *unknown_set; + } + } + + // If no errors or unknowns in input args, create new CelError for missing + // overload. + return frame.value_manager().CreateErrorValue( + cel::runtime_internal::CreateNoMatchingOverloadError( + absl::StrCat(name, CallArgTypeString(args)))); +} + absl::StatusOr AbstractFunctionStep::DoEvaluate( ExecutionFrame* frame) const { // Create Span object that contains input arguments to the function. @@ -176,53 +242,10 @@ absl::StatusOr AbstractFunctionStep::DoEvaluate( // Overload found and is allowed to consume the arguments. if (matched_function.has_value() && ShouldAcceptOverload(matched_function->descriptor, input_args)) { - FunctionEvaluationContext context(frame->value_factory()); - - CEL_ASSIGN_OR_RETURN(Value result, matched_function->implementation.Invoke( - context, input_args)); - - if (frame->enable_unknown_function_results() && - IsUnknownFunctionResultError(result)) { - return frame->attribute_utility().CreateUnknownSet( - matched_function->descriptor, id(), input_args); - } - return result; - } - - // No matching overloads. - // Such absence can be caused by presence of CelError in arguments. - // To enable behavior of functions that accept CelError( &&, || ), CelErrors - // should be propagated along execution path. - for (const auto& arg : input_args) { - if (arg->Is()) { - return arg; - } + return Invoke(*matched_function, id(), input_args, *frame); } - if (frame->enable_unknowns()) { - // Already converted partial unknowns to unknown sets so just merge. - absl::optional unknown_set = - frame->attribute_utility().MergeUnknowns(input_args); - if (unknown_set.has_value()) { - return *unknown_set; - } - } - - std::string arg_types; - for (const auto& arg : input_args) { - if (!arg_types.empty()) { - absl::StrAppend(&arg_types, ", "); - } - absl::StrAppend( - &arg_types, - ToLegacyKindName(cel::KindToString(ValueKindToKind(arg->kind())))); - } - - // If no errors or unknowns in input args, create new CelError for missing - // overload. - return frame->value_factory().CreateErrorValue( - cel::runtime_internal::CreateNoMatchingOverloadError( - absl::StrCat(name_, "(", arg_types, ")"))); + return NoOverloadResult(name_, input_args, *frame); } absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { @@ -240,27 +263,12 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } -class EagerFunctionStep : public AbstractFunctionStep { - public: - EagerFunctionStep(std::vector overloads, - const std::string& name, size_t num_args, int64_t expr_id) - : AbstractFunctionStep(name, num_args, expr_id), - overloads_(std::move(overloads)) {} - - absl::StatusOr ResolveFunction( - absl::Span input_args, - const ExecutionFrame* frame) const override; - - private: - std::vector overloads_; -}; - -absl::StatusOr EagerFunctionStep::ResolveFunction( +absl::StatusOr ResolveStatic( absl::Span input_args, - const ExecutionFrame* frame) const { + absl::Span overloads) { ResolveResult result = absl::nullopt; - for (const auto& overload : overloads_) { + for (const auto& overload : overloads) { if (ArgumentKindsMatch(overload.descriptor, input_args)) { // More than one overload matches our arguments. if (result.has_value()) { @@ -274,6 +282,63 @@ absl::StatusOr EagerFunctionStep::ResolveFunction( return result; } +absl::StatusOr ResolveLazy( + absl::Span input_args, absl::string_view name, + bool receiver_style, + absl::Span providers, + const ExecutionFrameBase& frame) { + ResolveResult result = absl::nullopt; + + std::vector arg_types(input_args.size()); + + std::transform( + input_args.begin(), input_args.end(), arg_types.begin(), + [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); + + cel::FunctionDescriptor matcher{name, receiver_style, arg_types}; + + const cel::ActivationInterface& activation = frame.activation(); + for (auto provider : providers) { + // The LazyFunctionStep has so far only resolved by function shape, check + // that the runtime argument kinds agree with the specific descriptor for + // the provider candidates. + if (!ArgumentKindsMatch(provider.descriptor, input_args)) { + continue; + } + + CEL_ASSIGN_OR_RETURN(auto overload, + provider.provider.GetFunction(matcher, activation)); + if (overload.has_value()) { + // More than one overload matches our arguments. + if (result.has_value()) { + return absl::Status(absl::StatusCode::kInternal, + "Cannot resolve overloads"); + } + + result.emplace(overload.value()); + } + } + + return result; +} + +class EagerFunctionStep : public AbstractFunctionStep { + public: + EagerFunctionStep(std::vector overloads, + const std::string& name, size_t num_args, int64_t expr_id) + : AbstractFunctionStep(name, num_args, expr_id), + overloads_(std::move(overloads)) {} + + absl::StatusOr ResolveFunction( + absl::Span input_args, + const ExecutionFrame* frame) const override { + return ResolveStatic(input_args, overloads_); + } + + private: + std::vector overloads_; +}; + class LazyFunctionStep : public AbstractFunctionStep { public: // Constructs LazyFunctionStep that attempts to lookup function implementation @@ -298,43 +363,136 @@ class LazyFunctionStep : public AbstractFunctionStep { absl::StatusOr LazyFunctionStep::ResolveFunction( absl::Span input_args, const ExecutionFrame* frame) const { - ResolveResult result = absl::nullopt; + return ResolveLazy(input_args, name_, receiver_style_, providers_, *frame); +} - std::vector arg_types(num_arguments_); +class StaticResolver { + public: + explicit StaticResolver(std::vector overloads) + : overloads_(std::move(overloads)) {} - std::transform( - input_args.begin(), input_args.end(), arg_types.begin(), - [](const cel::Value& value) { return ValueKindToKind(value->kind()); }); + absl::StatusOr Resolve(ExecutionFrameBase& frame, + absl::Span input) const { + return ResolveStatic(input, overloads_); + } - cel::FunctionDescriptor matcher{name_, receiver_style_, arg_types}; + private: + std::vector overloads_; +}; - const cel::ActivationInterface& activation = frame->modern_activation(); - for (auto provider : providers_) { - // The LazyFunctionStep has so far only resolved by function shape, check - // that the runtime argument kinds agree with the specific descriptor for - // the provider candidates. - if (!ArgumentKindsMatch(provider.descriptor, input_args)) { - continue; +class LazyResolver { + public: + explicit LazyResolver( + std::vector providers, + std::string name, bool receiver_style) + : providers_(std::move(providers)), + name_(std::move(name)), + receiver_style_(receiver_style) {} + + absl::StatusOr Resolve(ExecutionFrameBase& frame, + absl::Span input) const { + return ResolveLazy(input, name_, receiver_style_, providers_, frame); + } + + private: + std::vector providers_; + std::string name_; + bool receiver_style_; +}; + +template +class DirectFunctionStepImpl : public DirectExpressionStep { + public: + DirectFunctionStepImpl( + int64_t expr_id, const std::string& name, + std::vector> arg_steps, + Resolver&& resolver) + : DirectExpressionStep(expr_id), + name_(name), + arg_steps_(std::move(arg_steps)), + resolver_(std::forward(resolver)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + absl::InlinedVector args; + absl::InlinedVector arg_trails; + + args.resize(arg_steps_.size()); + arg_trails.resize(arg_steps_.size()); + + for (size_t i = 0; i < arg_steps_.size(); i++) { + CEL_RETURN_IF_ERROR( + arg_steps_[i]->Evaluate(frame, args[i], arg_trails[i])); } - CEL_ASSIGN_OR_RETURN(auto overload, - provider.provider.GetFunction(matcher, activation)); - if (overload.has_value()) { - // More than one overload matches our arguments. - if (result.has_value()) { - return absl::Status(absl::StatusCode::kInternal, - "Cannot resolve overloads"); + if (frame.unknown_processing_enabled()) { + for (size_t i = 0; i < arg_trails.size(); i++) { + if (frame.attribute_utility().CheckForUnknown(arg_trails[i], + /*use_partial=*/true)) { + args[i] = frame.attribute_utility().CreateUnknownSet( + arg_trails[i].attribute()); + } } + } - result.emplace(overload.value()); + CEL_ASSIGN_OR_RETURN(ResolveResult resolved_function, + resolver_.Resolve(frame, args)); + + if (resolved_function.has_value() && + ShouldAcceptOverload(resolved_function->descriptor, args)) { + CEL_ASSIGN_OR_RETURN(result, + Invoke(*resolved_function, expr_id_, args, frame)); + + return absl::OkStatus(); } + + result = NoOverloadResult(name_, args, frame); + + return absl::OkStatus(); } - return result; -} + absl::optional> GetDependencies() + const override { + std::vector dependencies; + dependencies.reserve(arg_steps_.size()); + for (const auto& arg_step : arg_steps_) { + dependencies.push_back(arg_step.get()); + } + return dependencies; + } + + absl::optional>> + ExtractDependencies() override { + return std::move(arg_steps_); + } + + private: + friend Resolver; + std::string name_; + std::vector> arg_steps_; + Resolver resolver_; +}; } // namespace +std::unique_ptr CreateDirectFunctionStep( + int64_t expr_id, const cel::ast_internal::Call& call, + std::vector> deps, + std::vector overloads) { + return std::make_unique>( + expr_id, call.function(), std::move(deps), + StaticResolver(std::move(overloads))); +} + +std::unique_ptr CreateDirectLazyFunctionStep( + int64_t expr_id, const cel::ast_internal::Call& call, + std::vector> deps, + std::vector providers) { + return std::make_unique>( + expr_id, call.function(), std::move(deps), + LazyResolver(std::move(providers), call.function(), call.has_target())); +} + absl::StatusOr> CreateFunctionStep( const cel::ast_internal::Call& call_expr, int64_t expr_id, std::vector lazy_overloads) { diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index f8317fde8..99444e3ab 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -7,11 +7,29 @@ #include "absl/status/statusor.h" #include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" +#include "runtime/function_overload_reference.h" #include "runtime/function_registry.h" namespace google::api::expr::runtime { +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of eagerly functions configured in the +// CelFunctionRegistry. +std::unique_ptr CreateDirectFunctionStep( + int64_t expr_id, const cel::ast_internal::Call& call, + std::vector> deps, + std::vector overloads); + +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of lazy functions configured in the +// CelFunctionRegistry. +std::unique_ptr CreateDirectLazyFunctionStep( + int64_t expr_id, const cel::ast_internal::Call& call, + std::vector> deps, + std::vector providers); + // Factory method for Call-based execution step where the function will be // resolved at runtime (lazily) from an input Activation. absl::StatusOr> CreateFunctionStep( diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 9109c0d8b..a5f0e53b8 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -1,6 +1,7 @@ #include "eval/eval/function_step.h" #include +#include #include #include #include @@ -8,9 +9,12 @@ #include "absl/strings/string_view.h" #include "base/ast_internal/expr.h" +#include "base/builtins.h" #include "base/type_provider.h" +#include "common/kind.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/internal/interop.h" @@ -24,9 +28,14 @@ #include "eval/public/structs/cel_proto_wrapper.h" #include "eval/public/testing/matchers.h" #include "eval/testutil/test_message.pb.h" -#include "internal/status_macros.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/testing.h" +#include "runtime/function_overload_reference.h" +#include "runtime/function_registry.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" +#include "runtime/standard_functions.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { @@ -38,6 +47,7 @@ using ::cel::ast_internal::Expr; using ::cel::ast_internal::Ident; using testing::Eq; using testing::Not; +using testing::Truly; using cel::internal::IsOk; using cel::internal::StatusIs; @@ -196,6 +206,17 @@ std::vector ArgumentMatcher(const Call& call) { : call.args().size()); } +std::unique_ptr CreateExpressionImpl( + const cel::RuntimeOptions& options, + std::unique_ptr expr) { + ExecutionPath path; + path.push_back(std::make_unique(std::move(expr), -1)); + + return std::make_unique( + FlatExpression(std::move(path), /*comprehension_slot_count=*/0, + TypeProvider::Builtin(), options)); +} + absl::StatusOr> MakeTestFunctionStep( const Call& call, const CelFunctionRegistry& registry) { auto argument_matcher = ArgumentMatcher(call); @@ -962,5 +983,185 @@ TEST(FunctionStepStrictnessTest, IfFunctionNonStrictAndGivenUnknownInvokesIt) { ASSERT_THAT(value, test::IsCelInt64(Eq(0))); } +class DirectFunctionStepTest : public testing::Test { + public: + DirectFunctionStepTest() + : value_factory_(TypeProvider::Builtin(), + cel::extensions::ProtoMemoryManagerRef(&arena_)) {} + + void SetUp() override { + ASSERT_OK(cel::RegisterStandardFunctions(registry_, options_)); + } + + std::vector GetOverloads( + absl::string_view name, int64_t arguments_size) { + std::vector matcher; + matcher.resize(arguments_size, cel::Kind::kAny); + return registry_.FindStaticOverloads(name, false, matcher); + } + + // Helper for shorthand constructing direct expr deps. + // + // Works around copies in init-list construction. + std::vector> MakeDeps( + std::unique_ptr dep, + std::unique_ptr dep2) { + std::vector> result; + result.reserve(2); + result.push_back(std::move(dep)); + result.push_back(std::move(dep2)); + return result; + }; + + protected: + cel::FunctionRegistry registry_; + cel::RuntimeOptions options_; + google::protobuf::Arena arena_; + cel::ManagedValueFactory value_factory_; +}; + +TEST_F(DirectFunctionStepTest, SimpleCall) { + value_factory_.get().CreateIntValue(1); + + cel::ast_internal::Call call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + std::vector> deps; + deps.push_back( + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1))); + deps.push_back( + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1))); + + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, test::IsCelInt64(2)); +} + +TEST_F(DirectFunctionStepTest, RecursiveCall) { + value_factory_.get().CreateIntValue(1); + + cel::ast_internal::Call call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + auto overloads = GetOverloads(cel::builtin::kAdd, 2); + + auto MakeLeaf = [&]() { + return CreateDirectFunctionStep( + -1, call, + MakeDeps( + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1)), + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1))), + overloads); + }; + + auto expr = CreateDirectFunctionStep( + -1, call, + MakeDeps(CreateDirectFunctionStep( + -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads), + CreateDirectFunctionStep( + -1, call, MakeDeps(MakeLeaf(), MakeLeaf()), overloads)), + overloads); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, test::IsCelInt64(8)); +} + +TEST_F(DirectFunctionStepTest, ErrorHandlingCall) { + value_factory_.get().CreateIntValue(1); + + cel::ast_internal::Call add_call; + add_call.set_function(cel::builtin::kAdd); + add_call.mutable_args().emplace_back(); + add_call.mutable_args().emplace_back(); + + cel::ast_internal::Call div_call; + div_call.set_function(cel::builtin::kDivide); + div_call.mutable_args().emplace_back(); + div_call.mutable_args().emplace_back(); + + auto add_overloads = GetOverloads(cel::builtin::kAdd, 2); + auto div_overloads = GetOverloads(cel::builtin::kDivide, 2); + + auto error_expr = CreateDirectFunctionStep( + -1, div_call, + MakeDeps( + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1)), + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(0))), + div_overloads); + + auto expr = CreateDirectFunctionStep( + -1, add_call, + MakeDeps( + std::move(error_expr), + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1))), + add_overloads); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, + test::IsCelError(StatusIs(absl::StatusCode::kInvalidArgument, + testing::HasSubstr("divide by zero")))); +} + +TEST_F(DirectFunctionStepTest, NoOverload) { + value_factory_.get().CreateIntValue(1); + + cel::ast_internal::Call call; + call.set_function(cel::builtin::kAdd); + call.mutable_args().emplace_back(); + call.mutable_args().emplace_back(); + + std::vector> deps; + deps.push_back( + CreateConstValueDirectStep(value_factory_.get().CreateIntValue(1))); + deps.push_back(CreateConstValueDirectStep( + value_factory_.get().CreateUncheckedStringValue("2"))); + + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, Truly(CheckNoMatchingOverloadError)); +} + +TEST_F(DirectFunctionStepTest, NoOverload0Args) { + value_factory_.get().CreateIntValue(1); + + cel::ast_internal::Call call; + call.set_function(cel::builtin::kAdd); + + std::vector> deps; + auto expr = CreateDirectFunctionStep(-1, call, std::move(deps), + GetOverloads(cel::builtin::kAdd, 2)); + + auto plan = CreateExpressionImpl(options_, std::move(expr)); + + Activation activation; + ASSERT_OK_AND_ASSIGN(auto value, plan->Evaluate(activation, &arena_)); + + EXPECT_THAT(value, Truly(CheckNoMatchingOverloadError)); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index aab002b82..dd0c561b3 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -9,8 +9,12 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/ast_internal/expr.h" +#include "common/value.h" #include "eval/eval/attribute_trail.h" #include "eval/eval/comprehension_slots.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" @@ -23,7 +27,6 @@ namespace { using ::cel::Value; using ::cel::ValueView; using ::cel::runtime_internal::CreateError; -using ::cel::runtime_internal::CreateMissingAttributeError; class IdentStep : public ExpressionStepBase { public: @@ -38,68 +41,76 @@ class IdentStep : public ExpressionStepBase { AttributeTrail trail; }; - absl::StatusOr DoEvaluate(ExecutionFrame* frame, - Value& scratch) const; - std::string name_; }; -absl::StatusOr IdentStep::DoEvaluate( - ExecutionFrame* frame, Value& scratch) const { - IdentResult result; - // Populate trails if either MissingAttributeError or UnknownPattern - // is enabled. - if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { - result.trail = AttributeTrail(name_); - } - - if (frame->enable_missing_attribute_errors() && !name_.empty() && - frame->attribute_utility().CheckForMissingAttribute(result.trail)) { - scratch = frame->value_factory().CreateErrorValue( - CreateMissingAttributeError(name_)); - result.value = scratch; - return result; - } - - if (frame->enable_unknowns()) { - if (frame->attribute_utility().CheckForUnknown(result.trail, false)) { - scratch = - frame->attribute_utility().CreateUnknownSet(result.trail.attribute()); - result.value = scratch; - return result; +absl::Status LookupIdent(const std::string& name, ExecutionFrameBase& frame, + Value& result, AttributeTrail& attribute) { + if (frame.attribute_tracking_enabled()) { + attribute = AttributeTrail(name); + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(attribute)) { + CEL_ASSIGN_OR_RETURN( + result, frame.attribute_utility().CreateMissingAttributeError( + attribute.attribute())); + return absl::OkStatus(); + } + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(attribute)) { + result = + frame.attribute_utility().CreateUnknownSet(attribute.attribute()); + return absl::OkStatus(); } } - CEL_ASSIGN_OR_RETURN(auto value, frame->modern_activation().FindVariable( - frame->value_factory(), name_, scratch)); + CEL_ASSIGN_OR_RETURN(auto value, frame.activation().FindVariable( + frame.value_manager(), name, result)); if (value.has_value()) { - result.value = *value; - return result; + result = *value; + return absl::OkStatus(); } - scratch = frame->value_factory().CreateErrorValue(CreateError( - absl::StrCat("No value with name \"", name_, "\" found in Activation"))); - result.value = scratch; + result = frame.value_manager().CreateErrorValue(CreateError( + absl::StrCat("No value with name \"", name, "\" found in Activation"))); - return result; + return absl::OkStatus(); } absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { - Value scratch; - CEL_ASSIGN_OR_RETURN(IdentResult result, DoEvaluate(frame, scratch)); + Value value; + AttributeTrail attribute; + + CEL_RETURN_IF_ERROR(LookupIdent(name_, *frame, value, attribute)); - frame->value_stack().Push(Value{result.value}, std::move(result.trail)); + frame->value_stack().Push(std::move(value), std::move(attribute)); return absl::OkStatus(); } +absl::StatusOr> LookupSlot( + absl::string_view name, size_t slot_index, ExecutionFrameBase& frame) { + const ComprehensionSlots::Slot* slot = + frame.comprehension_slots().Get(slot_index); + if (slot == nullptr) { + return absl::InternalError( + absl::StrCat("Comprehension variable accessed out of scope: ", name)); + } + return slot; +} + class SlotStep : public ExpressionStepBase { public: SlotStep(absl::string_view name, size_t slot_index, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name), slot_index_(slot_index) {} - absl::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override { + CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, + LookupSlot(name_, slot_index_, *frame)); + + frame->value_stack().Push(slot->value, slot->attribute); + return absl::OkStatus(); + } private: std::string name_; @@ -107,20 +118,57 @@ class SlotStep : public ExpressionStepBase { size_t slot_index_; }; -absl::Status SlotStep::Evaluate(ExecutionFrame* frame) const { - const ComprehensionSlots::Slot* slot = - frame->comprehension_slots().Get(slot_index_); - if (slot == nullptr) { - return absl::InternalError( - absl::StrCat("Comprehension variable accessed out of scope: ", name_)); +class DirectIdentStep : public DirectExpressionStep { + public: + DirectIdentStep(absl::string_view name, int64_t expr_id) + : DirectExpressionStep(expr_id), name_(name) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + return LookupIdent(name_, frame, result, attribute); } - frame->value_stack().Push(slot->value, slot->attribute); - return absl::OkStatus(); -} + private: + std::string name_; +}; + +class DirectSlotStep : public DirectExpressionStep { + public: + DirectSlotStep(std::string name, size_t slot_index, int64_t expr_id) + : DirectExpressionStep(expr_id), + name_(std::move(name)), + slot_index_(slot_index) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_ASSIGN_OR_RETURN(const ComprehensionSlots::Slot* slot, + LookupSlot(name_, slot_index_, frame)); + + if (frame.attribute_tracking_enabled()) { + attribute = slot->attribute; + } + result = slot->value; + return absl::OkStatus(); + } + + private: + std::string name_; + size_t slot_index_; +}; } // namespace +std::unique_ptr CreateDirectIdentStep( + absl::string_view identifier, int64_t expr_id) { + return std::make_unique(identifier, expr_id); +} + +std::unique_ptr CreateDirectSlotIdentStep( + absl::string_view identifier, size_t slot_index, int64_t expr_id) { + return std::make_unique(std::string(identifier), slot_index, + expr_id); +} + absl::StatusOr> CreateIdentStep( const cel::ast_internal::Ident& ident_expr, int64_t expr_id) { return std::make_unique(ident_expr.name(), expr_id); diff --git a/eval/eval/ident_step.h b/eval/eval/ident_step.h index a3a6da934..ab943737b 100644 --- a/eval/eval/ident_step.h +++ b/eval/eval/ident_step.h @@ -5,11 +5,19 @@ #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "base/ast_internal/expr.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +std::unique_ptr CreateDirectIdentStep( + absl::string_view identifier, int64_t expr_id); + +std::unique_ptr CreateDirectSlotIdentStep( + absl::string_view identifier, size_t slot_index, int64_t expr_id); + // Factory method for Ident - based Execution step absl::StatusOr> CreateIdentStep( const cel::ast_internal::Ident& ident, int64_t expr_id); diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index f21b447ca..6901edec3 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -5,21 +5,40 @@ #include #include +#include "absl/status/status.h" #include "base/type_provider.h" +#include "common/memory.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" #include "eval/eval/evaluator_core.h" #include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" namespace google::api::expr::runtime { namespace { +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ManagedValueFactory; +using ::cel::MemoryManagerRef; +using ::cel::RuntimeOptions; using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; using ::cel::ast_internal::Expr; using ::google::protobuf::Arena; using testing::Eq; +using testing::HasSubstr; +using testing::SizeIs; +using cel::internal::StatusIs; TEST(IdentStepTest, TestIdentStep) { Expr expr; @@ -197,6 +216,91 @@ TEST(IdentStepTest, TestIdentStepUnknownAttribute) { ASSERT_TRUE(result.IsUnknownSet()); } +TEST(DirectIdentStepTest, Basic) { + ManagedValueFactory value_factory(TypeProvider::Builtin(), + MemoryManagerRef::ReferenceCounting()); + cel::Activation activation; + RuntimeOptions options; + + activation.InsertOrAssignValue("var1", IntValue(42)); + + ExecutionFrameBase frame(activation, options, value_factory.get()); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), Eq(42)); +} + +TEST(DirectIdentStepTest, UnknownAttribute) { + ManagedValueFactory value_factory(TypeProvider::Builtin(), + MemoryManagerRef::ReferenceCounting()); + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + activation.InsertOrAssignValue("var1", IntValue(42)); + activation.SetUnknownPatterns({CreateCelAttributePattern("var1", {})}); + + ExecutionFrameBase frame(activation, options, value_factory.get()); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).attribute_set(), SizeIs(1)); +} + +TEST(DirectIdentStepTest, MissingAttribute) { + ManagedValueFactory value_factory(TypeProvider::Builtin(), + MemoryManagerRef::ReferenceCounting()); + cel::Activation activation; + RuntimeOptions options; + options.enable_missing_attribute_errors = true; + + activation.InsertOrAssignValue("var1", IntValue(42)); + activation.SetMissingPatterns({CreateCelAttributePattern("var1", {})}); + + ExecutionFrameBase frame(activation, options, value_factory.get()); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("var1"))); +} + +TEST(DirectIdentStepTest, NotFound) { + ManagedValueFactory value_factory(TypeProvider::Builtin(), + MemoryManagerRef::ReferenceCounting()); + cel::Activation activation; + RuntimeOptions options; + + ExecutionFrameBase frame(activation, options, value_factory.get()); + Value result; + AttributeTrail trail; + + auto step = CreateDirectIdentStep("var1", -1); + + ASSERT_OK(step->Evaluate(frame, result, trail)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("\"var1\" found in Activation"))); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/lazy_init_step.cc b/eval/eval/lazy_init_step.cc index da06c7bb7..07758f225 100644 --- a/eval/eval/lazy_init_step.cc +++ b/eval/eval/lazy_init_step.cc @@ -17,15 +17,24 @@ #include #include #include +#include +#include "google/api/expr/v1alpha1/value.pb.h" +#include "absl/base/nullability.h" #include "absl/status/status.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using ::cel::Value; + class CheckLazyInitStep : public ExpressionStepBase { public: CheckLazyInitStep(size_t slot_index, size_t subexpression_index, @@ -52,6 +61,57 @@ class CheckLazyInitStep : public ExpressionStepBase { size_t subexpression_index_; }; +class DirectCheckLazyInitStep : public DirectExpressionStep { + public: + DirectCheckLazyInitStep(size_t slot_index, + const DirectExpressionStep* subexpression, + int64_t expr_id) + : DirectExpressionStep(expr_id), + slot_index_(slot_index), + subexpression_(subexpression) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + auto* slot = frame.comprehension_slots().Get(slot_index_); + if (slot != nullptr) { + result = slot->value; + attribute = slot->attribute; + return absl::OkStatus(); + } + + CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); + frame.comprehension_slots().Set(slot_index_, result, attribute); + + return absl::OkStatus(); + } + + private: + size_t slot_index_; + absl::Nonnull subexpression_; +}; + +class BindStep : public DirectExpressionStep { + public: + BindStep(size_t slot_index, + std::unique_ptr subexpression, int64_t expr_id) + : DirectExpressionStep(expr_id), + slot_index_(slot_index), + subexpression_(std::move(subexpression)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_RETURN_IF_ERROR(subexpression_->Evaluate(frame, result, attribute)); + + frame.comprehension_slots().ClearSlot(slot_index_); + + return absl::OkStatus(); + } + + private: + size_t slot_index_; + std::unique_ptr subexpression_; +}; + class AssignSlotStep : public ExpressionStepBase { public: explicit AssignSlotStep(size_t slot_index, bool should_pop) @@ -95,6 +155,19 @@ class ClearSlotStep : public ExpressionStepBase { } // namespace +std::unique_ptr CreateDirectBindStep( + size_t slot_index, std::unique_ptr expression, + int64_t expr_id) { + return std::make_unique(slot_index, std::move(expression), expr_id); +} + +std::unique_ptr CreateDirectLazyInitStep( + size_t slot_index, absl::Nonnull subexpression, + int64_t expr_id) { + return std::make_unique(slot_index, subexpression, + expr_id); +} + std::unique_ptr CreateCheckLazyInitStep( size_t slot_index, size_t subexpression_index, int64_t expr_id) { return std::make_unique(slot_index, subexpression_index, diff --git a/eval/eval/lazy_init_step.h b/eval/eval/lazy_init_step.h index 5733afa7f..c7c593e4c 100644 --- a/eval/eval/lazy_init_step.h +++ b/eval/eval/lazy_init_step.h @@ -41,10 +41,23 @@ #include #include +#include "absl/base/nullability.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Creates a step representing a Bind expression. +std::unique_ptr CreateDirectBindStep( + size_t slot_index, std::unique_ptr expression, + int64_t expr_id); + +// Creates a direct step representing accessing a lazily evaluated alias from +// a bind or block. +std::unique_ptr CreateDirectLazyInitStep( + size_t slot_index, absl::Nonnull subexpression, + int64_t expr_id); + // Creates a guard step that checks that an alias is initialized. // If it is, push to stack and jump to the step that depends on the value. // Otherwise, run the initialization routine (which pushes the value to top of diff --git a/eval/eval/lazy_init_step_test.cc b/eval/eval/lazy_init_step_test.cc index a6eede774..9cf322662 100644 --- a/eval/eval/lazy_init_step_test.cc +++ b/eval/eval/lazy_init_step_test.cc @@ -80,7 +80,7 @@ TEST_F(LazyInitStepTest, CreateCheckInitStepDoesInit) { ExecutionFrame frame(expression_table, activation_, runtime_options_, evaluator_state_); - ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate(EvaluationListener())); + ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); EXPECT_TRUE(value->Is() && value->As().NativeValue() == 42); @@ -107,7 +107,7 @@ TEST_F(LazyInitStepTest, CreateCheckInitStepSkipInit) { ExecutionFrame frame(expression_table, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().Set(0, value_factory().CreateIntValue(42)); - ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate(EvaluationListener())); + ASSERT_OK_AND_ASSIGN(auto value, frame.Evaluate()); EXPECT_TRUE(value->Is() && value->As().NativeValue() == 42); @@ -124,7 +124,7 @@ TEST_F(LazyInitStepTest, CreateAssignSlotStepBasic) { frame.value_stack().Push(value_factory().CreateIntValue(42)); // This will error because no return value, step will still evaluate. - frame.Evaluate(EvaluationListener()).IgnoreError(); + frame.Evaluate().IgnoreError(); auto* slot = frame.comprehension_slots().Get(0); ASSERT_TRUE(slot != nullptr); @@ -144,7 +144,7 @@ TEST_F(LazyInitStepTest, CreateAssignSlotAndPopStepBasic) { frame.value_stack().Push(value_factory().CreateIntValue(42)); // This will error because no return value, step will still evaluate. - frame.Evaluate(EvaluationListener()).IgnoreError(); + frame.Evaluate().IgnoreError(); auto* slot = frame.comprehension_slots().Get(0); ASSERT_TRUE(slot != nullptr); @@ -161,7 +161,7 @@ TEST_F(LazyInitStepTest, CreateAssignSlotStepStackUnderflow) { ExecutionFrame frame(path, activation_, runtime_options_, evaluator_state_); frame.comprehension_slots().ClearSlot(0); - EXPECT_THAT(frame.Evaluate(EvaluationListener()), + EXPECT_THAT(frame.Evaluate(), StatusIs(absl::StatusCode::kInternal, HasSubstr("Stack underflow assigning lazy value"))); } @@ -175,7 +175,7 @@ TEST_F(LazyInitStepTest, CreateClearSlotStepBasic) { frame.comprehension_slots().Set(0, value_factory().CreateIntValue(42)); // This will error because no return value, step will still evaluate. - frame.Evaluate(EvaluationListener()).IgnoreError(); + frame.Evaluate().IgnoreError(); auto* slot = frame.comprehension_slots().Get(0); ASSERT_TRUE(slot == nullptr); diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index ba4fcac23..1f0407383 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -1,15 +1,25 @@ #include "eval/eval/logic_step.h" +#include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "base/builtins.h" +#include "common/casting.h" #include "common/value.h" +#include "common/value_kind.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" +#include "internal/status_macros.h" +#include "runtime/internal/errors.h" namespace google::api::expr::runtime { @@ -17,19 +27,173 @@ namespace { using ::cel::BoolValue; using ::cel::BoolValueView; - +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::UnknownValue; using ::cel::Value; +using ::cel::ValueKind; using ::cel::ValueView; using ::cel::runtime_internal::CreateNoMatchingOverloadError; -class LogicalOpStep : public ExpressionStepBase { +enum class OpType { kAnd, kOr }; + +// Shared logic for the fall through case (we didn't see the shortcircuit +// value). +absl::Status ReturnLogicResult(ExecutionFrameBase& frame, OpType op_type, + Value& lhs_result, Value& rhs_result, + AttributeTrail& attribute_trail, + AttributeTrail& rhs_attr) { + ValueKind lhs_kind = lhs_result.kind(); + ValueKind rhs_kind = rhs_result.kind(); + + if (frame.unknown_processing_enabled()) { + if (lhs_kind == ValueKind::kUnknown && rhs_kind == ValueKind::kUnknown) { + lhs_result = frame.attribute_utility().MergeUnknownValues( + Cast(lhs_result), Cast(rhs_result)); + // Clear attribute trail so this doesn't get re-identified as a new + // unknown and reset the accumulated attributes. + attribute_trail = AttributeTrail(); + return absl::OkStatus(); + } else if (lhs_kind == ValueKind::kUnknown) { + return absl::OkStatus(); + } else if (rhs_kind == ValueKind::kUnknown) { + lhs_result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + if (lhs_kind == ValueKind::kError) { + return absl::OkStatus(); + } else if (rhs_kind == ValueKind::kError) { + lhs_result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + + if (lhs_kind == ValueKind::kBool && rhs_kind == ValueKind::kBool) { + return absl::OkStatus(); + } + + // Otherwise, add a no overload error. + attribute_trail = AttributeTrail(); + lhs_result = + frame.value_manager().CreateErrorValue(CreateNoMatchingOverloadError( + op_type == OpType::kOr ? cel::builtin::kOr : cel::builtin::kAnd)); + return absl::OkStatus(); +} + +class ExhaustiveDirectLogicStep : public DirectExpressionStep { public: - enum class OpType { AND, OR }; + explicit ExhaustiveDirectLogicStep(std::unique_ptr lhs, + std::unique_ptr rhs, + OpType op_type, int64_t expr_id) + : DirectExpressionStep(expr_id), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + op_type_(op_type) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; + OpType op_type_; +}; + +absl::Status ExhaustiveDirectLogicStep::Evaluate( + ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); + ValueKind lhs_kind = result.kind(); + + Value rhs_result; + AttributeTrail rhs_attr; + CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); + + ValueKind rhs_kind = rhs_result.kind(); + if (lhs_kind == ValueKind::kBool) { + bool lhs_bool = Cast(result).NativeValue(); + if ((op_type_ == OpType::kOr && lhs_bool) || + (op_type_ == OpType::kAnd && !lhs_bool)) { + return absl::OkStatus(); + } + } + + if (rhs_kind == ValueKind::kBool) { + bool rhs_bool = Cast(rhs_result).NativeValue(); + if ((op_type_ == OpType::kOr && rhs_bool) || + (op_type_ == OpType::kAnd && !rhs_bool)) { + result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, + rhs_attr); +} +class DirectLogicStep : public DirectExpressionStep { + public: + explicit DirectLogicStep(std::unique_ptr lhs, + std::unique_ptr rhs, + OpType op_type, int64_t expr_id) + : DirectExpressionStep(expr_id), + lhs_(std::move(lhs)), + rhs_(std::move(rhs)), + op_type_(op_type) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute_trail) const override; + + private: + std::unique_ptr lhs_; + std::unique_ptr rhs_; + OpType op_type_; +}; + +absl::Status DirectLogicStep::Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute_trail) const { + CEL_RETURN_IF_ERROR(lhs_->Evaluate(frame, result, attribute_trail)); + ValueKind lhs_kind = result.kind(); + if (lhs_kind == ValueKind::kBool) { + bool lhs_bool = Cast(result).NativeValue(); + if ((op_type_ == OpType::kOr && lhs_bool) || + (op_type_ == OpType::kAnd && !lhs_bool)) { + return absl::OkStatus(); + } + } + + Value rhs_result; + AttributeTrail rhs_attr; + + CEL_RETURN_IF_ERROR(rhs_->Evaluate(frame, rhs_result, attribute_trail)); + + ValueKind rhs_kind = rhs_result.kind(); + + if (rhs_kind == ValueKind::kBool) { + bool rhs_bool = Cast(rhs_result).NativeValue(); + if ((op_type_ == OpType::kOr && rhs_bool) || + (op_type_ == OpType::kAnd && !rhs_bool)) { + result = std::move(rhs_result); + attribute_trail = std::move(rhs_attr); + return absl::OkStatus(); + } + } + + return ReturnLogicResult(frame, op_type_, result, rhs_result, attribute_trail, + rhs_attr); +} + +class LogicalOpStep : public ExpressionStepBase { + public: // Constructs FunctionStep that uses overloads specified. LogicalOpStep(OpType op_type, int64_t expr_id) : ExpressionStepBase(expr_id), op_type_(op_type) { - shortcircuit_ = (op_type_ == OpType::OR); + shortcircuit_ = (op_type_ == OpType::kOr); } absl::Status Evaluate(ExecutionFrame* frame) const override; @@ -52,9 +216,9 @@ class LogicalOpStep : public ExpressionStepBase { if (has_bool_args[0] && has_bool_args[1]) { switch (op_type_) { - case OpType::AND: + case OpType::kAnd: return BoolValueView{bool_args[0] && bool_args[1]}; - case OpType::OR: + case OpType::kOr: return BoolValueView{bool_args[0] || bool_args[1]}; } } @@ -82,7 +246,8 @@ class LogicalOpStep : public ExpressionStepBase { // Fallback. scratch = frame->value_factory().CreateErrorValue(CreateNoMatchingOverloadError( - (op_type_ == OpType::OR) ? cel::builtin::kOr : cel::builtin::kAnd)); + (op_type_ == OpType::kOr) ? cel::builtin::kOr + : cel::builtin::kAnd)); return scratch; } @@ -107,14 +272,42 @@ absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { } // namespace +// Factory method for "And" Execution step +std::unique_ptr CreateDirectAndStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting) { + if (shortcircuiting) { + return std::make_unique(std::move(lhs), std::move(rhs), + OpType::kAnd, expr_id); + } else { + return std::make_unique( + std::move(lhs), std::move(rhs), OpType::kAnd, expr_id); + } +} + +// Factory method for "Or" Execution step +std::unique_ptr CreateDirectOrStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting) { + if (shortcircuiting) { + return std::make_unique(std::move(lhs), std::move(rhs), + OpType::kOr, expr_id); + } else { + return std::make_unique( + std::move(lhs), std::move(rhs), OpType::kOr, expr_id); + } +} + // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id) { - return std::make_unique(LogicalOpStep::OpType::AND, expr_id); + return std::make_unique(OpType::kAnd, expr_id); } // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id) { - return std::make_unique(LogicalOpStep::OpType::OR, expr_id); + return std::make_unique(OpType::kOr, expr_id); } } // namespace google::api::expr::runtime diff --git a/eval/eval/logic_step.h b/eval/eval/logic_step.h index e626f9857..6f490435c 100644 --- a/eval/eval/logic_step.h +++ b/eval/eval/logic_step.h @@ -5,10 +5,23 @@ #include #include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Factory method for "And" Execution step +std::unique_ptr CreateDirectAndStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting); + +// Factory method for "Or" Execution step +std::unique_ptr CreateDirectOrStep( + std::unique_ptr lhs, + std::unique_ptr rhs, int64_t expr_id, + bool shortcircuiting); + // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id); diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc index f26dd922c..9aa1bc4e7 100644 --- a/eval/eval/logic_step_test.cc +++ b/eval/eval/logic_step_test.cc @@ -1,25 +1,58 @@ #include "eval/eval/logic_step.h" #include +#include +#include #include - +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "base/ast_internal/expr.h" +#include "base/attribute.h" +#include "base/attribute_set.h" #include "base/type_provider.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { +using ::cel::Attribute; +using ::cel::AttributeSet; +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ManagedValueFactory; using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; +using ::cel::ValueManager; using ::cel::ast_internal::Expr; +using ::cel::extensions::ProtoMemoryManagerRef; using ::google::protobuf::Arena; using testing::Eq; @@ -308,6 +341,244 @@ TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); + +enum class Op { kAnd, kOr }; + +enum class OpArg { + kTrue, + kFalse, + kUnknown, + kError, + // Arbitrary incorrect type + kInt +}; + +enum class OpResult { + kTrue, + kFalse, + kUnknown, + kError, +}; + +struct TestCase { + std::string name; + Op op; + OpArg arg0; + OpArg arg1; + OpResult result; +}; + +class DirectLogicStepTest + : public testing::TestWithParam> { + public: + DirectLogicStepTest() + : value_factory_(TypeProvider::Builtin(), + ProtoMemoryManagerRef(&arena_)) {} + + bool ShortcircuitingEnabled() { return std::get<0>(GetParam()); } + const TestCase& GetTestCase() { return std::get<1>(GetParam()); } + + ValueManager& value_manager() { return value_factory_.get(); } + + UnknownValue MakeUnknownValue(std::string attr) { + std::vector attrs; + attrs.push_back(Attribute(std::move(attr))); + return value_manager().CreateUnknownValue(AttributeSet(attrs)); + } + + protected: + Arena arena_; + ManagedValueFactory value_factory_; +}; + +TEST_P(DirectLogicStepTest, TestCases) { + const TestCase& test_case = GetTestCase(); + + auto MakeArg = + [&](OpArg arg, + absl::string_view name) -> std::unique_ptr { + switch (arg) { + case OpArg::kTrue: + return CreateConstValueDirectStep(BoolValue(true)); + case OpArg::kFalse: + return CreateConstValueDirectStep(BoolValue(false)); + case OpArg::kUnknown: + return CreateConstValueDirectStep(MakeUnknownValue(std::string(name))); + case OpArg::kError: + return CreateConstValueDirectStep( + value_manager().CreateErrorValue(absl::InternalError(name))); + case OpArg::kInt: + return CreateConstValueDirectStep(IntValue(42)); + } + }; + + std::unique_ptr lhs = MakeArg(test_case.arg0, "lhs"); + std::unique_ptr rhs = MakeArg(test_case.arg1, "rhs"); + + std::unique_ptr op = + (test_case.op == Op::kAnd) + ? CreateDirectAndStep(std::move(lhs), std::move(rhs), -1, + ShortcircuitingEnabled()) + : CreateDirectOrStep(std::move(lhs), std::move(rhs), -1, + ShortcircuitingEnabled()); + + cel::Activation activation; + cel::RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + ExecutionFrameBase frame(activation, options, value_manager()); + + Value value; + AttributeTrail attr; + ASSERT_OK(op->Evaluate(frame, value, attr)); + + switch (test_case.result) { + case OpResult::kTrue: + ASSERT_TRUE(InstanceOf(value)); + EXPECT_TRUE(Cast(value).NativeValue()); + break; + case OpResult::kFalse: + ASSERT_TRUE(InstanceOf(value)); + EXPECT_FALSE(Cast(value).NativeValue()); + break; + case OpResult::kUnknown: + EXPECT_TRUE(InstanceOf(value)); + break; + case OpResult::kError: + EXPECT_TRUE(InstanceOf(value)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + DirectLogicStepTest, DirectLogicStepTest, + testing::Combine(testing::Bool(), + testing::ValuesIn>({ + { + "AndFalseFalse", + Op::kAnd, + OpArg::kFalse, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndFalseTrue", + Op::kAnd, + OpArg::kFalse, + OpArg::kTrue, + OpResult::kFalse, + }, + { + "AndTrueFalse", + Op::kAnd, + OpArg::kTrue, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndTrueTrue", + Op::kAnd, + OpArg::kTrue, + OpArg::kTrue, + OpResult::kTrue, + }, + + { + "AndTrueError", + Op::kAnd, + OpArg::kTrue, + OpArg::kError, + OpResult::kError, + }, + { + "AndErrorTrue", + Op::kAnd, + OpArg::kError, + OpArg::kTrue, + OpResult::kError, + }, + { + "AndFalseError", + Op::kAnd, + OpArg::kFalse, + OpArg::kError, + OpResult::kFalse, + }, + { + "AndErrorFalse", + Op::kAnd, + OpArg::kError, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndErrorError", + Op::kAnd, + OpArg::kError, + OpArg::kError, + OpResult::kError, + }, + + { + "AndTrueUnknown", + Op::kAnd, + OpArg::kTrue, + OpArg::kUnknown, + OpResult::kUnknown, + }, + { + "AndUnknownTrue", + Op::kAnd, + OpArg::kUnknown, + OpArg::kTrue, + OpResult::kUnknown, + }, + { + "AndFalseUnknown", + Op::kAnd, + OpArg::kFalse, + OpArg::kUnknown, + OpResult::kFalse, + }, + { + "AndUnknownFalse", + Op::kAnd, + OpArg::kUnknown, + OpArg::kFalse, + OpResult::kFalse, + }, + { + "AndUnknownUnknown", + Op::kAnd, + OpArg::kUnknown, + OpArg::kUnknown, + OpResult::kUnknown, + }, + { + "AndUnknownError", + Op::kAnd, + OpArg::kUnknown, + OpArg::kError, + OpResult::kUnknown, + }, + { + "AndErrorUnknown", + Op::kAnd, + OpArg::kError, + OpArg::kUnknown, + OpResult::kUnknown, + }, + + // Or cases are simplified since the logic generalizes + // and is covered by and cases. + })), + [](const testing::TestParamInfo& info) + -> std::string { + bool shortcircuiting_enabled = std::get<0>(info.param); + absl::string_view name = std::get<1>(info.param).name; + return absl::StrCat( + name, (shortcircuiting_enabled ? "ShortcircuitingEnabled" : "")); + }); + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/regex_match_step.cc b/eval/eval/regex_match_step.cc index d8acd3a17..1d9405ceb 100644 --- a/eval/eval/regex_match_step.cc +++ b/eval/eval/regex_match_step.cc @@ -14,18 +14,36 @@ #include "eval/eval/regex_match_step.h" +#include +#include #include +#include #include #include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "common/casting.h" #include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" +#include "internal/status_macros.h" #include "re2/re2.h" namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::StringValue; +using ::cel::UnknownValue; +using ::cel::Value; + inline constexpr int kNumRegexMatchArguments = 1; inline constexpr size_t kRegexMatchStepSubject = 0; @@ -74,8 +92,48 @@ class RegexMatchStep final : public ExpressionStepBase { const std::shared_ptr re2_; }; +class RegexMatchDirectStep final : public DirectExpressionStep { + public: + RegexMatchDirectStep(int64_t expr_id, + std::unique_ptr subject, + std::shared_ptr re2) + : DirectExpressionStep(expr_id), + subject_(std::move(subject)), + re2_(std::move(re2)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + AttributeTrail subject_attr; + CEL_RETURN_IF_ERROR(subject_->Evaluate(frame, result, subject_attr)); + if (InstanceOf(result) || + cel::InstanceOf(result)) { + return absl::OkStatus(); + } + + if (!InstanceOf(result)) { + return absl::Status(absl::StatusCode::kInternal, + "First argument for regular " + "expression match must be a string"); + } + bool match = Cast(result).NativeValue(MatchesVisitor{*re2_}); + result = BoolValue(match); + return absl::OkStatus(); + } + + private: + std::unique_ptr subject_; + const std::shared_ptr re2_; +}; + } // namespace +std::unique_ptr CreateDirectRegexMatchStep( + int64_t expr_id, std::unique_ptr subject, + std::shared_ptr re2) { + return std::make_unique(expr_id, std::move(subject), + std::move(re2)); +} + absl::StatusOr> CreateRegexMatchStep( std::shared_ptr re2, int64_t expr_id) { return std::make_unique(expr_id, std::move(re2)); diff --git a/eval/eval/regex_match_step.h b/eval/eval/regex_match_step.h index 5ed638fbb..1d8a09118 100644 --- a/eval/eval/regex_match_step.h +++ b/eval/eval/regex_match_step.h @@ -15,14 +15,20 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_REGEX_MATCH_STEP_H_ +#include #include #include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "re2/re2.h" namespace google::api::expr::runtime { +std::unique_ptr CreateDirectRegexMatchStep( + int64_t expr_id, std::unique_ptr subject, + std::shared_ptr re2); + absl::StatusOr> CreateRegexMatchStep( std::shared_ptr re2, int64_t expr_id); diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 001dea303..aa3e3edaa 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -6,9 +6,9 @@ #include #include +#include "absl/log/absl_log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "base/kind.h" @@ -16,6 +16,8 @@ #include "common/native_type.h" #include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" @@ -28,9 +30,12 @@ namespace google::api::expr::runtime { namespace { using ::cel::BoolValueView; +using ::cel::Cast; using ::cel::ErrorValue; +using ::cel::InstanceOf; using ::cel::MapValue; using ::cel::NullValue; +using ::cel::OptionalValue; using ::cel::ProtoWrapperTypeOptions; using ::cel::StringValue; using ::cel::StructValue; @@ -38,7 +43,6 @@ using ::cel::UnknownValue; using ::cel::Value; using ::cel::ValueKind; using ::cel::ValueView; -using ::cel::runtime_internal::CreateMissingAttributeError; // Common error for cases where evaluation attempts to perform select operations // on an unsupported type. @@ -51,27 +55,25 @@ absl::Status InvalidSelectTargetError() { } absl::optional CheckForMarkedAttributes(const AttributeTrail& trail, - ExecutionFrame* frame) { - if (frame->enable_unknowns() && - frame->attribute_utility().CheckForUnknown(trail, - /*use_partial=*/false)) { - return frame->attribute_utility().CreateUnknownSet(trail.attribute()); + ExecutionFrameBase& frame) { + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(trail)) { + return frame.attribute_utility().CreateUnknownSet(trail.attribute()); } - if (frame->enable_missing_attribute_errors() && - frame->attribute_utility().CheckForMissingAttribute(trail)) { - auto attribute_string = trail.attribute().AsString(); - if (attribute_string.ok()) { - return frame->value_factory().CreateErrorValue( - CreateMissingAttributeError(*attribute_string)); + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(trail)) { + auto result = frame.attribute_utility().CreateMissingAttributeError( + trail.attribute()); + + if (result.ok()) { + return std::move(result).value(); } // Invariant broken (an invalid CEL Attribute shouldn't match anything). // Log and return a CelError. - ABSL_LOG(ERROR) - << "Invalid attribute pattern matched select path: " - << attribute_string.status().ToString(); // NOLINT: OSS compatibility - return frame->value_factory().CreateErrorValue( - std::move(attribute_string).status()); + ABSL_LOG(ERROR) << "Invalid attribute pattern matched select path: " + << result.status().ToString(); // NOLINT: OSS compatibility + return frame.value_manager().CreateErrorValue(std::move(result).status()); } return absl::nullopt; @@ -92,7 +94,8 @@ ValueView TestOnlySelect(const MapValue& map, const StringValue& field_name, cel::ValueManager& value_factory, Value& scratch) { // Field presence only supports string keys containing valid identifier // characters. - auto presence = map.Has(value_factory, field_name, scratch); + absl::StatusOr presence = + map.Has(value_factory, field_name, scratch); if (!presence.ok()) { scratch = value_factory.CreateErrorValue(std::move(presence).status()); @@ -102,8 +105,6 @@ ValueView TestOnlySelect(const MapValue& map, const StringValue& field_name, return *presence; } -} // namespace - // SelectStep performs message field access specified by Expr::Select // message. class SelectStep : public ExpressionStepBase { @@ -143,7 +144,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { const Value& arg = frame->value_stack().Peek(); const AttributeTrail& trail = frame->value_stack().PeekAttribute(); - if (arg->Is() || arg->Is()) { + if (InstanceOf(arg) || InstanceOf(arg)) { // Bubble up unknowns and errors. return absl::OkStatus(); } @@ -181,7 +182,7 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { } absl::optional marked_attribute_check = - CheckForMarkedAttributes(result_trail, frame); + CheckForMarkedAttributes(result_trail, *frame); if (marked_attribute_check.has_value()) { frame->value_stack().PopAndPush(std::move(marked_attribute_check).value(), std::move(result_trail)); @@ -291,6 +292,204 @@ absl::StatusOr> SelectStep::PerformSelect( } } +class DirectSelectStep : public DirectExpressionStep { + public: + DirectSelectStep(int64_t expr_id, + std::unique_ptr operand, + StringValue field, bool test_only, + bool enable_wrapper_type_null_unboxing, + bool enable_optional_types) + : DirectExpressionStep(expr_id), + operand_(std::move(operand)), + field_value_(std::move(field)), + field_(field_value_.ToString()), + test_only_(test_only), + unboxing_option_(enable_wrapper_type_null_unboxing + ? ProtoWrapperTypeOptions::kUnsetNull + : ProtoWrapperTypeOptions::kUnsetProtoDefault), + enable_optional_types_(enable_optional_types) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute)); + + if (InstanceOf(result) || InstanceOf(result)) { + // Just forward. + return absl::OkStatus(); + } + + if (frame.attribute_tracking_enabled()) { + attribute = attribute.Step(&field_); + absl::optional value = CheckForMarkedAttributes(attribute, frame); + if (value.has_value()) { + result = std::move(value).value(); + return absl::OkStatus(); + } + } + + const cel::OptionalValueInterface* optional_arg = nullptr; + + if (enable_optional_types_ && + cel::NativeTypeId::Of(result) == + cel::NativeTypeId::For()) { + optional_arg = + cel::internal::down_cast( + cel::Cast(result).operator->()); + } + + switch (result.kind()) { + case ValueKind::kStruct: + case ValueKind::kMap: + break; + case ValueKind::kNull: + result = frame.value_manager().CreateErrorValue( + cel::runtime_internal::CreateError("Message is NULL")); + return absl::OkStatus(); + default: + if (optional_arg != nullptr) { + break; + } + result = + frame.value_manager().CreateErrorValue(InvalidSelectTargetError()); + return absl::OkStatus(); + } + + Value scratch; + if (test_only_) { + if (optional_arg != nullptr) { + if (!optional_arg->HasValue()) { + result = cel::BoolValue{false}; + return absl::OkStatus(); + } + result = PerformTestOnlySelect(frame, optional_arg->Value(), scratch); + return absl::OkStatus(); + } + result = PerformTestOnlySelect(frame, result, scratch); + return absl::OkStatus(); + } + + if (optional_arg != nullptr) { + if (!optional_arg->HasValue()) { + // result is still buffer for the container. just return. + return absl::OkStatus(); + } + CEL_ASSIGN_OR_RETURN( + result, PerformOptionalSelect(frame, optional_arg->Value(), scratch)); + + return absl::OkStatus(); + } + + CEL_ASSIGN_OR_RETURN(result, PerformSelect(frame, result, scratch)); + return absl::OkStatus(); + } + + private: + std::unique_ptr operand_; + + ValueView PerformTestOnlySelect(ExecutionFrameBase& frame, const Value& value, + Value& scratch) const; + absl::StatusOr PerformOptionalSelect(ExecutionFrameBase& frame, + const Value& value, + Value& scratch) const; + absl::StatusOr PerformSelect(ExecutionFrameBase& frame, + const Value& value, + Value& scratch) const; + + // Field name in formats supported by each of the map and struct field access + // APIs. + // + // ToString or ValueManager::CreateString may force a copy so we do this at + // plan time. + StringValue field_value_; + std::string field_; + + // whether this is a has() expression. + bool test_only_; + ProtoWrapperTypeOptions unboxing_option_; + bool enable_optional_types_; +}; + +ValueView DirectSelectStep::PerformTestOnlySelect(ExecutionFrameBase& frame, + const cel::Value& value, + Value& scratch) const { + switch (value.kind()) { + case ValueKind::kMap: + return TestOnlySelect(Cast(value), field_value_, + frame.value_manager(), scratch); + case ValueKind::kMessage: + return TestOnlySelect(Cast(value), field_, + frame.value_manager(), scratch); + default: + // Control flow should have returned earlier. + scratch = + frame.value_manager().CreateErrorValue(InvalidSelectTargetError()); + return ValueView{scratch}; + } +} + +absl::StatusOr DirectSelectStep::PerformOptionalSelect( + ExecutionFrameBase& frame, const Value& value, Value& scratch) const { + switch (value.kind()) { + case ValueKind::kStruct: { + auto struct_value = Cast(value); + CEL_ASSIGN_OR_RETURN(auto ok, struct_value.HasFieldByName(field_)); + if (!ok) { + scratch = OptionalValue::None(); + return ValueView{scratch}; + } + CEL_ASSIGN_OR_RETURN(auto result, struct_value.GetFieldByName( + frame.value_manager(), field_, + scratch, unboxing_option_)); + scratch = OptionalValue::Of(frame.value_manager().GetMemoryManager(), + Value(result)); + return ValueView{scratch}; + } + case ValueKind::kMap: { + CEL_ASSIGN_OR_RETURN(auto lookup, + Cast(value).Find(frame.value_manager(), + field_value_, scratch)); + if (!lookup.second) { + scratch = OptionalValue::None(); + return ValueView{scratch}; + } + scratch = OptionalValue::Of(frame.value_manager().GetMemoryManager(), + Value(lookup.first)); + return ValueView{scratch}; + } + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} + +absl::StatusOr DirectSelectStep::PerformSelect( + ExecutionFrameBase& frame, const cel::Value& value, Value& scratch) const { + switch (value.kind()) { + case ValueKind::kStruct: { + return Cast(value).GetFieldByName( + frame.value_manager(), field_, scratch, unboxing_option_); + } + case ValueKind::kMap: { + return Cast(value).Get(frame.value_manager(), field_value_, + scratch); + } + default: + // Control flow should have returned earlier. + return InvalidSelectTargetError(); + } +} + +} // namespace + +std::unique_ptr CreateDirectSelectStep( + std::unique_ptr operand, StringValue field, + bool test_only, int64_t expr_id, bool enable_wrapper_type_null_unboxing, + bool enable_optional_types) { + return std::make_unique( + expr_id, std::move(operand), std::move(field), test_only, + enable_wrapper_type_null_unboxing, enable_optional_types); +} + // Factory method for Select - based Execution step absl::StatusOr> CreateSelectStep( const cel::ast_internal::Select& select_expr, int64_t expr_id, diff --git a/eval/eval/select_step.h b/eval/eval/select_step.h index 0e99e78f1..5f2ef7c68 100644 --- a/eval/eval/select_step.h +++ b/eval/eval/select_step.h @@ -5,13 +5,20 @@ #include #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "base/ast_internal/expr.h" +#include "common/value.h" #include "common/value_manager.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Factory method for recursively evaluated select step. +std::unique_ptr CreateDirectSelectStep( + std::unique_ptr operand, cel::StringValue field, + bool test_only, int64_t expr_id, bool enable_wrapper_type_null_unboxing, + bool enable_optional_types = false); + // Factory method for Select - based Execution step absl::StatusOr> CreateSelectStep( const cel::ast_internal::Select& select_expr, int64_t expr_id, diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 0cc361333..a3dd839c0 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -6,15 +6,22 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/wrappers.pb.h" +#include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "base/ast_internal/expr.h" +#include "base/attribute.h" +#include "base/attribute_set.h" #include "base/type_provider.h" -#include "common/type_factory.h" -#include "common/type_manager.h" +#include "common/legacy_value.h" +#include "common/value.h" #include "common/value_manager.h" +#include "common/value_testing.h" #include "common/values/legacy_value_manager.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" @@ -28,26 +35,44 @@ #include "eval/testutil/test_extensions.pb.h" #include "eval/testutil/test_message.pb.h" #include "extensions/protobuf/memory_manager.h" +#include "internal/proto_matchers.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" -#include "testutil/util.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" namespace google::api::expr::runtime { namespace { +using ::cel::Attribute; +using ::cel::AttributeQualifier; +using ::cel::AttributeSet; +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::ManagedValueFactory; +using ::cel::OptionalValue; +using ::cel::RuntimeOptions; using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::Value; using ::cel::ast_internal::Expr; using ::cel::extensions::ProtoMemoryManagerRef; +using ::cel::internal::test::EqualsProto; +using ::cel::test::IntValueIs; +using ::google::api::expr::test::v1::proto3::TestAllTypes; using testing::_; using testing::Eq; using testing::HasSubstr; using testing::Return; +using testing::UnorderedElementsAre; using cel::internal::StatusIs; -using testutil::EqualsProto; - struct RunExpressionOptions { bool enable_unknowns = false; bool enable_wrapper_type_null_unboxing = false; @@ -1024,6 +1049,424 @@ TEST_F(SelectStepTest, UnknownPatternResolvesToUnknown) { INSTANTIATE_TEST_SUITE_P(UnknownsEnabled, SelectStepConformanceTest, testing::Bool()); +class DirectSelectStepTest : public testing::Test { + public: + DirectSelectStepTest() + : value_manager_(TypeProvider::Builtin(), + ProtoMemoryManagerRef(&arena_)) {} + + cel::Value TestWrapMessage(const google::protobuf::Message* message) { + CelValue value = CelProtoWrapper::CreateMessage(message, &arena_); + auto result = cel::interop_internal::FromLegacyValue(&arena_, value); + ABSL_DCHECK_OK(result.status()); + return std::move(result).value(); + } + + std::vector AttributeStrings(const UnknownValue& v) { + std::vector result; + for (const Attribute& attr : v.attribute_set()) { + auto attr_str = attr.AsString(); + ABSL_DCHECK_OK(attr_str.status()); + result.push_back(std::move(attr_str).value()); + } + return result; + } + + protected: + google::protobuf::Arena arena_; + ManagedValueFactory value_manager_; +}; + +TEST_F(DirectSelectStepTest, SelectFromMap) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + ASSERT_OK_AND_ASSIGN(auto map_builder, + value_manager_.get().NewMapValueBuilder( + value_manager_.get().GetDynDynMapType())); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("one"), IntValue(1))); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("two"), IntValue(2))); + activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_F(DirectSelectStepTest, HasMap) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + ASSERT_OK_AND_ASSIGN(auto map_builder, + value_manager_.get().NewMapValueBuilder( + value_manager_.get().GetDynDynMapType())); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("one"), IntValue(1))); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("two"), IntValue(2))); + activation.InsertOrAssignValue("map_val", std::move(*map_builder).Build()); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_TRUE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + ASSERT_OK_AND_ASSIGN(auto map_builder, + value_manager_.get().NewMapValueBuilder( + value_manager_.get().GetDynDynMapType())); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("one"), IntValue(1))); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("two"), IntValue(2))); + activation.InsertOrAssignValue( + "map_val", OptionalValue::Of(value_manager_.get().GetMemoryManager(), + std::move(*map_builder).Build())); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(static_cast(result)).Value(), + IntValueIs(1)); +} + +TEST_F(DirectSelectStepTest, SelectFromEmptyOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + activation.InsertOrAssignValue("map_val", OptionalValue::None()); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE( + cel::Cast(static_cast(result)).HasValue()); +} + +TEST_F(DirectSelectStepTest, HasOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + ASSERT_OK_AND_ASSIGN(auto map_builder, + value_manager_.get().NewMapValueBuilder( + value_manager_.get().GetDynDynMapType())); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("one"), IntValue(1))); + ASSERT_OK(map_builder->Put( + value_manager_.get().CreateUncheckedStringValue("two"), IntValue(2))); + activation.InsertOrAssignValue( + "map_val", OptionalValue::Of(value_manager_.get().GetMemoryManager(), + std::move(*map_builder).Build())); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_TRUE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, HasEmptyOptional) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("map_val", -1), + value_manager_.get().CreateUncheckedStringValue("two"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true, + /*enable_optional_types=*/true); + + activation.InsertOrAssignValue("map_val", OptionalValue::None()); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_FALSE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromStruct) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("test_all_types", -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_F(DirectSelectStepTest, HasStruct) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("test_all_types", -1), + value_manager_.get().CreateUncheckedStringValue("single_string"), + /*test_only=*/true, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + + // has(test_all_types.single_string) + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_FALSE(Cast(result).NativeValue()); +} + +TEST_F(DirectSelectStepTest, SelectFromUnsupportedType) { + cel::Activation activation; + RuntimeOptions options; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("bool_val", -1), + value_manager_.get().CreateUncheckedStringValue("one"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + activation.InsertOrAssignValue("bool_val", BoolValue(false)); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Applying SELECT to non-message type"))); +} + +TEST_F(DirectSelectStepTest, AttributeUpdatedIfRequested) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("test_all_types", -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 1); + + ASSERT_OK_AND_ASSIGN(std::string attr_str, attr.attribute().AsString()); + EXPECT_EQ(attr_str, "test_all_types.single_int64"); +} + +TEST_F(DirectSelectStepTest, MissingAttributesToErrors) { + cel::Activation activation; + RuntimeOptions options; + options.enable_missing_attribute_errors = true; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("test_all_types", -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + activation.SetMissingPatterns({cel::AttributePattern( + "test_all_types", + {cel::AttributeQualifierPattern::OfString("single_int64")})}); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("test_all_types.single_int64"))); +} + +TEST_F(DirectSelectStepTest, IdentifiesUnknowns) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = CreateDirectSelectStep( + CreateDirectIdentStep("test_all_types", -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + activation.SetUnknownPatterns({cel::AttributePattern( + "test_all_types", + {cel::AttributeQualifierPattern::OfString("single_int64")})}); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + + EXPECT_THAT(AttributeStrings(Cast(result)), + UnorderedElementsAre("test_all_types.single_int64")); +} + +TEST_F(DirectSelectStepTest, ForwardErrorValue) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + auto step = CreateDirectSelectStep( + CreateConstValueDirectStep( + value_manager_.get().CreateErrorValue(absl::InternalError("test1")), + -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, HasSubstr("test1"))); +} + +TEST_F(DirectSelectStepTest, ForwardUnknownOperand) { + cel::Activation activation; + RuntimeOptions options; + options.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + + AttributeSet attr_set({Attribute("attr", {AttributeQualifier::OfInt(0)})}); + auto step = CreateDirectSelectStep( + CreateConstValueDirectStep( + value_manager_.get().CreateUnknownValue(std::move(attr_set)), -1), + value_manager_.get().CreateUncheckedStringValue("single_int64"), + /*test_only=*/false, -1, + /*enable_wrapper_type_null_unboxing=*/true); + + TestAllTypes message; + message.set_single_int64(1); + activation.InsertOrAssignValue("test_all_types", TestWrapMessage(&message)); + + ExecutionFrameBase frame(activation, options, value_manager_.get()); + + Value result; + AttributeTrail attr; + ASSERT_OK(step->Evaluate(frame, result, attr)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(AttributeStrings(Cast(result)), + UnorderedElementsAre("attr[0]")); +} + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/shadowable_value_step.cc b/eval/eval/shadowable_value_step.cc index d359b9421..00693ea21 100644 --- a/eval/eval/shadowable_value_step.cc +++ b/eval/eval/shadowable_value_step.cc @@ -5,7 +5,13 @@ #include #include +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "internal/status_macros.h" @@ -13,6 +19,8 @@ namespace google::api::expr::runtime { namespace { +using ::cel::Value; + class ShadowableValueStep : public ExpressionStepBase { public: ShadowableValueStep(std::string identifier, cel::Value value, int64_t expr_id) @@ -24,7 +32,7 @@ class ShadowableValueStep : public ExpressionStepBase { private: std::string identifier_; - cel::Value value_; + Value value_; }; absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { @@ -40,6 +48,39 @@ absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { return absl::OkStatus(); } +class DirectShadowableValueStep : public DirectExpressionStep { + public: + DirectShadowableValueStep(std::string identifier, cel::Value value, + int64_t expr_id) + : DirectExpressionStep(expr_id), + identifier_(std::move(identifier)), + value_(std::move(value)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + std::string identifier_; + Value value_; +}; + +// TODO(uncreated-issue/67): Attribute tracking is skipped for the shadowed case. May +// cause problems for users with unknown tracking and variables named like +// 'list' etc, but follows the current behavior of the stack machine version. +absl::Status DirectShadowableValueStep::Evaluate( + ExecutionFrameBase& frame, Value& result, AttributeTrail& attribute) const { + cel::Value scratch; + CEL_ASSIGN_OR_RETURN(auto var, + frame.activation().FindVariable(frame.value_manager(), + identifier_, scratch)); + if (var.has_value()) { + result = *var; + } else { + result = value_; + } + return absl::OkStatus(); +} + } // namespace absl::StatusOr> CreateShadowableValueStep( @@ -48,4 +89,10 @@ absl::StatusOr> CreateShadowableValueStep( std::move(value), expr_id); } +std::unique_ptr CreateDirectShadowableValueStep( + std::string identifier, cel::Value value, int64_t expr_id) { + return std::make_unique(std::move(identifier), + std::move(value), expr_id); +} + } // namespace google::api::expr::runtime diff --git a/eval/eval/shadowable_value_step.h b/eval/eval/shadowable_value_step.h index 4f00672a1..21c6753d5 100644 --- a/eval/eval/shadowable_value_step.h +++ b/eval/eval/shadowable_value_step.h @@ -7,6 +7,7 @@ #include "absl/status/statusor.h" #include "common/value.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { @@ -17,6 +18,9 @@ namespace google::api::expr::runtime { absl::StatusOr> CreateShadowableValueStep( std::string identifier, cel::Value value, int64_t expr_id); +std::unique_ptr CreateDirectShadowableValueStep( + std::string identifier, cel::Value value, int64_t expr_id); + } // namespace google::api::expr::runtime #endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index f41291fb2..3a2850823 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -1,19 +1,31 @@ #include "eval/eval/ternary_step.h" +#include #include #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "base/builtins.h" +#include "common/casting.h" #include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/internal/errors.h" +#include "internal/status_macros.h" namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::UnknownValue; using ::cel::builtin::kTernary; using ::cel::runtime_internal::CreateNoMatchingOverloadError; @@ -21,6 +33,104 @@ inline constexpr size_t kTernaryStepCondition = 0; inline constexpr size_t kTernaryStepTrue = 1; inline constexpr size_t kTernaryStepFalse = 2; +class ExhaustiveDirectTernaryStep : public DirectExpressionStep { + public: + ExhaustiveDirectTernaryStep(std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, + int64_t expr_id) + : DirectExpressionStep(expr_id), + condition_(std::move(condition)), + left_(std::move(left)), + right_(std::move(right)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override { + cel::Value condition; + cel::Value lhs; + cel::Value rhs; + + AttributeTrail condition_attr; + AttributeTrail lhs_attr; + AttributeTrail rhs_attr; + + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + CEL_RETURN_IF_ERROR(left_->Evaluate(frame, lhs, lhs_attr)); + CEL_RETURN_IF_ERROR(right_->Evaluate(frame, rhs, rhs_attr)); + + if (InstanceOf(condition) || + InstanceOf(condition)) { + result = std::move(condition); + attribute = std::move(condition_attr); + return absl::OkStatus(); + } + + if (!InstanceOf(condition)) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError(kTernary)); + return absl::OkStatus(); + } + + if (Cast(condition).NativeValue()) { + result = std::move(lhs); + attribute = std::move(lhs_attr); + } else { + result = std::move(rhs); + attribute = std::move(rhs_attr); + } + return absl::OkStatus(); + } + + private: + std::unique_ptr condition_; + std::unique_ptr left_; + std::unique_ptr right_; +}; + +class ShortcircuitingDirectTernaryStep : public DirectExpressionStep { + public: + ShortcircuitingDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id) + : DirectExpressionStep(expr_id), + condition_(std::move(condition)), + left_(std::move(left)), + right_(std::move(right)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& attribute) const override { + cel::Value condition; + + AttributeTrail condition_attr; + + CEL_RETURN_IF_ERROR(condition_->Evaluate(frame, condition, condition_attr)); + + if (InstanceOf(condition) || + InstanceOf(condition)) { + result = std::move(condition); + attribute = std::move(condition_attr); + return absl::OkStatus(); + } + + if (!InstanceOf(condition)) { + result = frame.value_manager().CreateErrorValue( + CreateNoMatchingOverloadError(kTernary)); + return absl::OkStatus(); + } + + if (Cast(condition).NativeValue()) { + return left_->Evaluate(frame, result, attribute); + } + return right_->Evaluate(frame, result, attribute); + } + + private: + std::unique_ptr condition_; + std::unique_ptr left_; + std::unique_ptr right_; +}; + class TernaryStep : public ExpressionStepBase { public: // Constructs FunctionStep that uses overloads specified. @@ -72,6 +182,21 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { } // namespace +// Factory method for ternary (_?_:_) recursive execution step +std::unique_ptr CreateDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id, + bool shortcircuiting) { + if (shortcircuiting) { + return std::make_unique( + std::move(condition), std::move(left), std::move(right), expr_id); + } + + return std::make_unique( + std::move(condition), std::move(left), std::move(right), expr_id); +} + absl::StatusOr> CreateTernaryStep( int64_t expr_id) { return std::make_unique(expr_id); diff --git a/eval/eval/ternary_step.h b/eval/eval/ternary_step.h index de43a03d0..2b51e95ea 100644 --- a/eval/eval/ternary_step.h +++ b/eval/eval/ternary_step.h @@ -2,12 +2,21 @@ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TERNARY_STEP_H_ #include +#include #include "absl/status/statusor.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" namespace google::api::expr::runtime { +// Factory method for ternary (_?_:_) recursive execution step +std::unique_ptr CreateDirectTernaryStep( + std::unique_ptr condition, + std::unique_ptr left, + std::unique_ptr right, int64_t expr_id, + bool shortcircuiting = true); + // Factory method for ternary (_?_:_) execution step absl::StatusOr> CreateTernaryStep( int64_t expr_id); diff --git a/eval/eval/ternary_step_test.cc b/eval/eval/ternary_step_test.cc index 5e7da72fa..467d62c17 100644 --- a/eval/eval/ternary_step_test.cc +++ b/eval/eval/ternary_step_test.cc @@ -1,28 +1,58 @@ #include "eval/eval/ternary_step.h" +#include #include #include +#include -#include "google/protobuf/descriptor.h" +#include "absl/base/nullability.h" +#include "absl/status/status.h" +#include "base/ast_internal/expr.h" +#include "base/attribute.h" +#include "base/attribute_set.h" #include "base/type_provider.h" +#include "common/casting.h" +#include "common/value.h" +#include "common/value_manager.h" +#include "eval/eval/attribute_trail.h" #include "eval/eval/cel_expression_flat_impl.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/ident_step.h" #include "eval/public/activation.h" +#include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_set.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "runtime/activation.h" +#include "runtime/managed_value_factory.h" #include "runtime/runtime_options.h" +#include "google/protobuf/arena.h" namespace google::api::expr::runtime { namespace { +using ::cel::BoolValue; +using ::cel::Cast; +using ::cel::ErrorValue; +using ::cel::InstanceOf; +using ::cel::IntValue; +using ::cel::RuntimeOptions; using ::cel::TypeProvider; +using ::cel::UnknownValue; +using ::cel::ValueManager; using ::cel::ast_internal::Expr; +using ::cel::extensions::ProtoMemoryManagerRef; using ::google::protobuf::Arena; +using testing::ElementsAre; using testing::Eq; +using testing::HasSubstr; +using testing::Truly; +using cel::internal::StatusIs; class LogicStepTest : public testing::TestWithParam { public: @@ -173,6 +203,174 @@ TEST_F(LogicStepTest, TestUnknownHandling) { } INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); + +class TernaryStepDirectTest : public testing::TestWithParam { + public: + TernaryStepDirectTest() + : value_factory_(TypeProvider::Builtin(), + ProtoMemoryManagerRef(&arena_)) {} + + bool Shortcircuiting() { return GetParam(); } + + ValueManager& value_manager() { return value_factory_.get(); } + + protected: + Arena arena_; + cel::ManagedValueFactory value_factory_; +}; + +TEST_P(TernaryStepDirectTest, ReturnLhs) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, value_manager()); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(true), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 1); +} + +TEST_P(TernaryStepDirectTest, ReturnRhs) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, value_manager()); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(false), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_EQ(Cast(result).NativeValue(), 2); +} + +TEST_P(TernaryStepDirectTest, ForwardError) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, value_manager()); + + cel::Value error_value = + value_manager().CreateErrorValue(absl::InternalError("test error")); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(error_value, -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kInternal, "test error")); +} + +TEST_P(TernaryStepDirectTest, ForwardUnknown) { + cel::Activation activation; + RuntimeOptions opts; + opts.unknown_processing = cel::UnknownProcessingOptions::kAttributeOnly; + ExecutionFrameBase frame(activation, opts, value_manager()); + + std::vector attrs{{cel::Attribute("var")}}; + + cel::UnknownValue unknown_value = + value_manager().CreateUnknownValue(cel::AttributeSet(attrs)); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(unknown_value, -1), + CreateConstValueDirectStep(IntValue(2), -1), + CreateConstValueDirectStep(IntValue(3), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue().unknown_attributes(), + ElementsAre(Truly([](const cel::Attribute& attr) { + return attr.variable_name() == "var"; + }))); +} + +TEST_P(TernaryStepDirectTest, UnexpectedCondtionKind) { + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, value_manager()); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(IntValue(-1), -1), + CreateConstValueDirectStep(IntValue(1), -1), + CreateConstValueDirectStep(IntValue(2), -1), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), + StatusIs(absl::StatusCode::kUnknown, + HasSubstr("No matching overloads found"))); +} + +TEST_P(TernaryStepDirectTest, Shortcircuiting) { + class RecordCallStep : public DirectExpressionStep { + public: + explicit RecordCallStep(bool& was_called) + : DirectExpressionStep(-1), was_called_(&was_called) {} + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + *was_called_ = true; + result = IntValue(1); + return absl::OkStatus(); + } + + private: + absl::Nonnull was_called_; + }; + + bool lhs_was_called = false; + bool rhs_was_called = false; + + cel::Activation activation; + RuntimeOptions opts; + ExecutionFrameBase frame(activation, opts, value_manager()); + + std::unique_ptr step = CreateDirectTernaryStep( + CreateConstValueDirectStep(BoolValue(false), -1), + std::make_unique(lhs_was_called), + std::make_unique(rhs_was_called), -1, Shortcircuiting()); + + cel::Value result; + AttributeTrail attr_unused; + + ASSERT_OK(step->Evaluate(frame, result, attr_unused)); + + ASSERT_TRUE(InstanceOf(result)); + EXPECT_THAT(Cast(result).NativeValue(), Eq(1)); + bool expect_eager_eval = !Shortcircuiting(); + EXPECT_EQ(lhs_was_called, expect_eager_eval); + EXPECT_TRUE(rhs_was_called); +} + +INSTANTIATE_TEST_SUITE_P(TernaryStepDirectTest, TernaryStepDirectTest, + testing::Bool()); + } // namespace } // namespace google::api::expr::runtime diff --git a/eval/eval/trace_step.h b/eval/eval/trace_step.h new file mode 100644 index 000000000..fa14dfbcc --- /dev/null +++ b/eval/eval/trace_step.h @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "common/native_type.h" +#include "common/value.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" +#include "eval/eval/evaluator_core.h" +#include "internal/status_macros.h" +namespace google::api::expr::runtime { + +// A decorator that implements tracing for recursively evaluated CEL +// expressions. +// +// Allows inspection for extensions to extract the wrapped expression. +class TraceStep : public DirectExpressionStep { + public: + explicit TraceStep(std::unique_ptr expression) + : DirectExpressionStep(-1), expression_(std::move(expression)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, cel::Value& result, + AttributeTrail& trail) const override { + CEL_RETURN_IF_ERROR(expression_->Evaluate(frame, result, trail)); + if (!frame.callback()) { + return absl::OkStatus(); + } + return frame.callback()(expression_->expr_id(), result, + frame.value_manager()); + } + + cel::NativeTypeId GetNativeTypeId() const override { + return cel::NativeTypeId::For(); + } + + absl::optional> GetDependencies() + const override { + return {{expression_.get()}}; + } + + absl::optional>> + ExtractDependencies() override { + std::vector> dependencies; + dependencies.push_back(std::move(expression_)); + return dependencies; + }; + + private: + std::unique_ptr expression_; +}; + +} // namespace google::api::expr::runtime + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_TRACE_STEP_H_ diff --git a/eval/public/cel_options.cc b/eval/public/cel_options.cc index d22434eaf..65b6517e7 100644 --- a/eval/public/cel_options.cc +++ b/eval/public/cel_options.cc @@ -37,7 +37,9 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) { options.enable_qualified_type_identifiers, options.enable_heterogeneous_equality, options.enable_empty_wrapper_null_unboxing, - options.enable_lazy_bind_initialization}; + options.enable_lazy_bind_initialization, + options.max_recursion_depth, + options.enable_recursive_tracing}; } } // namespace google::api::expr::runtime diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index e45f10eee..a9e1dbb67 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -164,6 +164,17 @@ struct InterpreterOptions { // This is now always enabled. Setting this option has no effect. It will be // removed in a later update. bool enable_lazy_bind_initialization = true; + + // Maximum recursion depth for evaluable programs. + // + // -1 means unbounded. + int max_recursion_depth = -1; // DO NOT SUBMIT + + // Enable tracing support for recursively planned programs. + // + // Unlike the stack machine implementation, supporting tracing can affect + // performance whether or not tracing is requested for a given evaluation. + bool enable_recursive_tracing = true; }; // LINT.ThenChange(//depot/google3/runtime/runtime_options.h) diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 8f4bb1536..be0f98b1f 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -65,6 +65,14 @@ cc_test( ], ) +cc_test( + name = "recursive_benchmark_test", + size = "small", + args = ["--enable_recursive_planning"], + tags = ["benchmark"], + deps = [":benchmark_testlib"], +) + cc_test( name = "allocation_benchmark_test", size = "small", diff --git a/eval/tests/benchmark_test.cc b/eval/tests/benchmark_test.cc index ce03c3dcf..53266f4eb 100644 --- a/eval/tests/benchmark_test.cc +++ b/eval/tests/benchmark_test.cc @@ -30,6 +30,7 @@ #include "google/protobuf/arena.h" ABSL_FLAG(bool, enable_optimizations, false, "enable const folding opt"); +ABSL_FLAG(bool, enable_recursive_planning, false, "enable recursive planning"); namespace google { namespace api { @@ -51,6 +52,10 @@ InterpreterOptions GetOptions(google::protobuf::Arena& arena) { options.constant_folding = true; } + if (absl::GetFlag(FLAGS_enable_recursive_planning)) { + options.max_recursion_depth = -1; + } + return options; } @@ -105,6 +110,7 @@ absl::Status EmptyCallback(int64_t expr_id, const CelValue& value, static void BM_Eval_Trace(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); @@ -189,6 +195,7 @@ BENCHMARK(BM_EvalString)->Range(1, 10000); static void BM_EvalString_Trace(benchmark::State& state) { google::protobuf::Arena arena; InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); @@ -408,7 +415,7 @@ comprehension_expr: < iter_range: < id: 2 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -463,7 +470,7 @@ void BM_Comprehension(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; @@ -496,8 +503,10 @@ void BM_Comprehension_Trace(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); + options.enable_recursive_tracing = true; + options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); @@ -756,7 +765,7 @@ comprehension_expr: < iter_range: < id: 2 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -783,7 +792,7 @@ comprehension_expr: < iter_range: < id: 9 ident_expr: < - name: "list" + name: "list_var" > > accu_init: < @@ -854,7 +863,7 @@ void BM_NestedComprehension(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; auto builder = CreateCelExpressionBuilder(options); @@ -887,10 +896,12 @@ void BM_NestedComprehension_Trace(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN(auto cel_expr, @@ -910,7 +921,7 @@ void BM_ListComprehension(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, - parser::Parse("list.map(x, x * 2)")); + parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; @@ -920,7 +931,7 @@ void BM_ListComprehension(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; @@ -943,7 +954,7 @@ void BM_ListComprehension_Trace(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, - parser::Parse("list.map(x, x * 2)")); + parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; @@ -953,10 +964,12 @@ void BM_ListComprehension_Trace(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options = GetOptions(arena); options.comprehension_max_iterations = 10000000; options.enable_comprehension_list_append = true; + options.enable_recursive_tracing = true; + auto builder = CreateCelExpressionBuilder(options); ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); ASSERT_OK_AND_ASSIGN( @@ -976,7 +989,7 @@ void BM_ListComprehension_Opt(benchmark::State& state) { google::protobuf::Arena arena; Activation activation; ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, - parser::Parse("list.map(x, x * 2)")); + parser::Parse("list_var.map(x, x * 2)")); int len = state.range(0); std::vector list; @@ -986,7 +999,7 @@ void BM_ListComprehension_Opt(benchmark::State& state) { } ContainerBackedListImpl cel_list(std::move(list)); - activation.InsertValue("list", CelValue::CreateList(&cel_list)); + activation.InsertValue("list_var", CelValue::CreateList(&cel_list)); InterpreterOptions options; options.constant_arena = &arena; options.constant_folding = true; diff --git a/extensions/BUILD b/extensions/BUILD index 56b49c3f6..a76620131 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -214,12 +214,14 @@ cc_library( "//base:builtins", "//base/ast_internal:ast_impl", "//base/ast_internal:expr", + "//common:casting", "//common:kind", "//common:memory", "//common:type", "//common:value", "//eval/compiler:flat_expr_builder_extensions", "//eval/eval:attribute_trail", + "//eval/eval:direct_expression_step", "//eval/eval:evaluator_core", "//eval/eval:expression_step_base", "//eval/public:ast_rewrite_native", diff --git a/extensions/bindings_ext_test.cc b/extensions/bindings_ext_test.cc index 7aa7ad0ab..3b3a645a6 100644 --- a/extensions/bindings_ext_test.cc +++ b/extensions/bindings_ext_test.cc @@ -90,10 +90,11 @@ std::unique_ptr CreateBindFunction() { } class BindingsExtTest - : public testing::TestWithParam> { + : public testing::TestWithParam> { protected: const TestInfo& GetTestInfo() { return std::get<0>(GetParam()); } bool GetEnableConstantFolding() { return std::get<1>(GetParam()); } + bool GetEnableRecursivePlan() { return std::get<2>(GetParam()); } }; TEST_P(BindingsExtTest, EndToEnd) { @@ -121,6 +122,7 @@ TEST_P(BindingsExtTest, EndToEnd) { options.enable_empty_wrapper_null_unboxing = true; options.constant_folding = GetEnableConstantFolding(); options.constant_arena = &arena; + options.max_recursion_depth = GetEnableRecursivePlan() ? -1 : 0; std::unique_ptr builder = CreateCelExpressionBuilder(options); ASSERT_OK(builder->GetRegistry()->Register(CreateBindFunction())); @@ -169,7 +171,8 @@ INSTANTIATE_TEST_SUITE_P( // Error case where the variable name is not a simple identifier. {"cel.bind(bad.name, true, bad.name)", "variable name must be a simple identifier"}}), - /*constant_folding*/ testing::Bool())); + /*constant_folding*/ testing::Bool(), + /*recursive_plan*/ testing::Bool())); // Test bind expression with nested field selection. // diff --git a/extensions/select_optimization.cc b/extensions/select_optimization.cc index 5c4763a73..48ad883f7 100644 --- a/extensions/select_optimization.cc +++ b/extensions/select_optimization.cc @@ -38,6 +38,7 @@ #include "base/ast_internal/expr.h" #include "base/attribute.h" #include "base/builtins.h" +#include "common/casting.h" #include "common/kind.h" #include "common/memory.h" #include "common/type.h" @@ -46,6 +47,7 @@ #include "common/value_manager.h" #include "eval/compiler/flat_expr_builder_extensions.h" #include "eval/eval/attribute_trail.h" +#include "eval/eval/direct_expression_step.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/public/ast_rewrite_native.h" @@ -66,7 +68,9 @@ using ::cel::ast_internal::ExprKind; using ::cel::ast_internal::Select; using ::cel::ast_internal::SourcePosition; using ::google::api::expr::runtime::AttributeTrail; +using ::google::api::expr::runtime::DirectExpressionStep; using ::google::api::expr::runtime::ExecutionFrame; +using ::google::api::expr::runtime::ExecutionFrameBase; using ::google::api::expr::runtime::ExpressionStepBase; using ::google::api::expr::runtime::PlannerContext; using ::google::api::expr::runtime::ProgramOptimizer; @@ -528,15 +532,12 @@ class RewriterImpl : public AstRewriterBase { absl::Status progress_status_; }; -class OptimizedSelectStep : public ExpressionStepBase { +class OptimizedSelectImpl { public: - OptimizedSelectStep( - int expr_id, std::vector select_path, - std::vector qualifiers, bool presence_test, - ABSL_ATTRIBUTE_UNUSED bool enable_wrapper_type_null_unboxing, - SelectOptimizationOptions options) - : ExpressionStepBase(expr_id), - select_path_(std::move(select_path)), + OptimizedSelectImpl(std::vector select_path, + std::vector qualifiers, + bool presence_test, SelectOptimizationOptions options) + : select_path_(std::move(select_path)), qualifiers_(std::move(qualifiers)), presence_test_(presence_test), options_(options) @@ -545,17 +546,24 @@ class OptimizedSelectStep : public ExpressionStepBase { ABSL_DCHECK(!select_path_.empty()); } - absl::Status Evaluate(ExecutionFrame* frame) const override; + // Move constructible. + OptimizedSelectImpl(const OptimizedSelectImpl&) = delete; + OptimizedSelectImpl& operator=(const OptimizedSelectImpl&) = delete; + OptimizedSelectImpl(OptimizedSelectImpl&&) = default; + OptimizedSelectImpl& operator=(OptimizedSelectImpl&&) = delete; - private: - absl::StatusOr ApplySelect(ExecutionFrame* frame, + absl::StatusOr ApplySelect(ExecutionFrameBase& frame, const StructValue& struct_value) const; - // Get the effective attribute for the optimized select expression. - // Assumes the operand is the top of stack if the attribute wasn't known at - // plan time. - AttributeTrail GetAttributeTrail(ExecutionFrame* frame) const; + AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; + + absl::optional attribute() const { return attribute_; } + const std::vector& qualifiers() const { + return qualifiers_; + } + + private: absl::optional attribute_; std::vector select_path_; std::vector qualifiers_; @@ -565,12 +573,13 @@ class OptimizedSelectStep : public ExpressionStepBase { // Check for unknowns or missing attributes. absl::StatusOr> CheckForMarkedAttributes( - ExecutionFrame* frame, const AttributeTrail& attribute_trail) { + ExecutionFrameBase& frame, const AttributeTrail& attribute_trail) { if (attribute_trail.empty()) { return absl::nullopt; } - if (frame->enable_unknowns()) { + if (frame.unknown_processing_enabled() && + frame.attribute_utility().CheckForUnknownExact(attribute_trail)) { // Check if the inferred attribute is marked. Only matches if this attribute // or a parent is marked unknown (use_partial = false). // Partial matches (i.e. descendant of this attribute is marked) aren't @@ -580,49 +589,30 @@ absl::StatusOr> CheckForMarkedAttributes( // TODO(uncreated-issue/51): this may return a more specific attribute than the // declared pattern. Follow up will truncate the returned attribute to match // the pattern. - if (frame->attribute_utility().CheckForUnknown(attribute_trail, - /*use_partial=*/false)) { - return frame->attribute_utility().CreateUnknownSet( - attribute_trail.attribute()); - } + return frame.attribute_utility().CreateUnknownSet( + attribute_trail.attribute()); } - if (frame->enable_missing_attribute_errors()) { - if (frame->attribute_utility().CheckForMissingAttribute(attribute_trail)) { - return frame->attribute_utility().CreateMissingAttributeError( - attribute_trail.attribute()); - } + + if (frame.missing_attribute_errors_enabled() && + frame.attribute_utility().CheckForMissingAttribute(attribute_trail)) { + return frame.attribute_utility().CreateMissingAttributeError( + attribute_trail.attribute()); } return absl::nullopt; } -AttributeTrail OptimizedSelectStep::GetAttributeTrail( - ExecutionFrame* frame) const { - const auto& attr = frame->value_stack().PeekAttribute(); - if (attr.empty()) { - return AttributeTrail(); - } - std::vector qualifiers = - std::vector(attr.attribute().qualifier_path().begin(), - attr.attribute().qualifier_path().end()); - qualifiers.reserve(qualifiers_.size() + qualifiers.size()); - absl::c_copy(qualifiers_, std::back_inserter(qualifiers)); - AttributeTrail result(Attribute(std::string(attr.attribute().variable_name()), - std::move(qualifiers))); - return result; -} - -absl::StatusOr OptimizedSelectStep::ApplySelect( - ExecutionFrame* frame, const StructValue& struct_value) const { +absl::StatusOr OptimizedSelectImpl::ApplySelect( + ExecutionFrameBase& frame, const StructValue& struct_value) const { auto value_or = (options_.force_fallback_implementation) ? absl::UnimplementedError("Forced fallback impl") - : struct_value.Qualify(frame->value_factory(), + : struct_value.Qualify(frame.value_manager(), select_path_, presence_test_); if (!value_or.ok()) { if (value_or.status().code() == absl::StatusCode::kUnimplemented) { return FallbackSelect(struct_value, select_path_, presence_test_, - frame->value_factory()); + frame.value_manager()); } return value_or.status(); @@ -635,10 +625,47 @@ absl::StatusOr OptimizedSelectStep::ApplySelect( return FallbackSelect( value_or->first, absl::MakeConstSpan(select_path_).subspan(value_or->second), - presence_test_, frame->value_factory()); + presence_test_, frame.value_manager()); +} + +AttributeTrail OptimizedSelectImpl::GetAttributeTrail( + const AttributeTrail& operand_trail) const { + if (operand_trail.empty()) { + return AttributeTrail(); + } + std::vector qualifiers = std::vector( + operand_trail.attribute().qualifier_path().begin(), + operand_trail.attribute().qualifier_path().end()); + qualifiers.reserve(qualifiers_.size() + qualifiers.size()); + absl::c_copy(qualifiers_, std::back_inserter(qualifiers)); + return AttributeTrail( + Attribute(std::string(operand_trail.attribute().variable_name()), + std::move(qualifiers))); } -absl::Status OptimizedSelectStep::Evaluate(ExecutionFrame* frame) const { +class StackMachineImpl : public ExpressionStepBase { + public: + StackMachineImpl(int expr_id, OptimizedSelectImpl impl) + : ExpressionStepBase(expr_id), impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + // Get the effective attribute for the optimized select expression. + // Assumes the operand is the top of stack if the attribute wasn't known at + // plan time. + AttributeTrail GetAttributeTrail(ExecutionFrame* frame) const; + + OptimizedSelectImpl impl_; +}; + +AttributeTrail StackMachineImpl::GetAttributeTrail( + ExecutionFrame* frame) const { + const auto& attr = frame->value_stack().PeekAttribute(); + return impl_.GetAttributeTrail(attr); +} + +absl::Status StackMachineImpl::Evaluate(ExecutionFrame* frame) const { // Default empty. AttributeTrail attribute_trail; // TODO(uncreated-issue/51): add support for variable qualifiers and string literal @@ -660,7 +687,7 @@ absl::Status OptimizedSelectStep::Evaluate(ExecutionFrame* frame) const { // TODO(uncreated-issue/51): add support variable qualifiers attribute_trail = GetAttributeTrail(frame); CEL_ASSIGN_OR_RETURN(absl::optional value, - CheckForMarkedAttributes(frame, attribute_trail)); + CheckForMarkedAttributes(*frame, attribute_trail)); if (value.has_value()) { frame->value_stack().Pop(kStackInputs); frame->value_stack().Push(std::move(value).value(), @@ -669,21 +696,72 @@ absl::Status OptimizedSelectStep::Evaluate(ExecutionFrame* frame) const { } } - - if (!operand->Is()) { return absl::InvalidArgumentError( "Expected struct type for select optimization."); } CEL_ASSIGN_OR_RETURN(Value result, - ApplySelect(frame, operand->As())); + impl_.ApplySelect(*frame, operand->As())); frame->value_stack().Pop(kStackInputs); frame->value_stack().Push(std::move(result), std::move(attribute_trail)); return absl::OkStatus(); } +class RecursiveImpl : public DirectExpressionStep { + public: + RecursiveImpl(int64_t expr_id, std::unique_ptr operand, + OptimizedSelectImpl impl) + : DirectExpressionStep(expr_id), + operand_(std::move(operand)), + impl_(std::move(impl)) {} + + absl::Status Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const override; + + private: + // Get the effective attribute for the optimized select expression. + // Assumes the operand is the top of stack if the attribute wasn't known at + // plan time. + AttributeTrail GetAttributeTrail(const AttributeTrail& operand_trail) const; + std::unique_ptr operand_; + OptimizedSelectImpl impl_; +}; + +AttributeTrail RecursiveImpl::GetAttributeTrail( + const AttributeTrail& operand_trail) const { + return impl_.GetAttributeTrail(operand_trail); +} + +absl::Status RecursiveImpl::Evaluate(ExecutionFrameBase& frame, Value& result, + AttributeTrail& attribute) const { + CEL_RETURN_IF_ERROR(operand_->Evaluate(frame, result, attribute)); + + if (InstanceOf(result) || InstanceOf(result)) { + // Just forward. + return absl::OkStatus(); + } + + if (frame.attribute_tracking_enabled()) { + attribute = impl_.GetAttributeTrail(attribute); + CEL_ASSIGN_OR_RETURN(auto value, + CheckForMarkedAttributes(frame, attribute)); + if (value.has_value()) { + result = std::move(value).value(); + return absl::OkStatus(); + } + } + + if (!InstanceOf(result)) { + return absl::InvalidArgumentError( + "Expected struct type for select optimization"); + } + CEL_ASSIGN_OR_RETURN(result, + impl_.ApplySelect(frame, Cast(result))); + return absl::OkStatus(); +} + class SelectOptimizer : public ProgramOptimizer { public: explicit SelectOptimizer(const SelectOptimizationOptions& options) @@ -758,6 +836,28 @@ absl::Status SelectOptimizer::OnPostVisit(PlannerContext& context, // TODO(uncreated-issue/51): If the first argument is a string literal, the custom // step needs to handle variable lookup. + auto* subexpression = context.program_builder().GetSubexpression(&node); + if (subexpression == nullptr || subexpression->IsFlattened()) { + // No information on the subprogram, can't optimize. + return absl::OkStatus(); + } + + OptimizedSelectImpl impl(std::move(instructions), std::move(qualifiers), + presence_test, options_); + + if (subexpression->IsRecursive()) { + auto program = subexpression->ExtractRecursiveProgram(); + auto deps = program.step->ExtractDependencies(); + if (!deps.has_value() || deps->empty()) { + return absl::InvalidArgumentError("Unexpected cel.@attribute call"); + } + subexpression->set_recursive_program( + std::make_unique(node.id(), std::move(deps->at(0)), + std::move(impl)), + program.depth); + return absl::OkStatus(); + } + google::api::expr::runtime::ExecutionPath path; // else, we need to preserve the original plan for the first argument. @@ -768,11 +868,8 @@ absl::Status SelectOptimizer::OnPostVisit(PlannerContext& context, CEL_ASSIGN_OR_RETURN(auto operand_subplan, context.ExtractSubplan(operand)); absl::c_move(operand_subplan, std::back_inserter(path)); - bool enable_wrapper_type_null_unboxing = - context.options().enable_empty_wrapper_null_unboxing; - path.push_back(std::make_unique( - node.id(), std::move(instructions), std::move(qualifiers), presence_test, - enable_wrapper_type_null_unboxing, options_)); + path.push_back( + std::make_unique(node.id(), std::move(impl))); return context.ReplaceSubplan(node, std::move(path)); } diff --git a/runtime/runtime_options.h b/runtime/runtime_options.h index 351213b7e..8363a1dc8 100644 --- a/runtime/runtime_options.h +++ b/runtime/runtime_options.h @@ -134,6 +134,17 @@ struct RuntimeOptions { // This is now always enabled. Setting this option has no effect. It will be // removed in a later update. bool enable_lazy_bind_initialization = true; + + // Maximum recursion depth for evaluable programs. + // + // -1 means unbounded. + int max_recursion_depth = -1; // DO NOT SUBMIT + + // Enable tracing support for recursively planned programs. + // + // Unlike the stack machine implementation, supporting tracing can affect + // performance whether or not tracing is requested for a given evaluation. + bool enable_recursive_tracing = true; }; // LINT.ThenChange(//depot/google3/eval/public/cel_options.h) diff --git a/tools/branch_coverage_test.cc b/tools/branch_coverage_test.cc index 1d211e191..235d11ffc 100644 --- a/tools/branch_coverage_test.cc +++ b/tools/branch_coverage_test.cc @@ -92,6 +92,17 @@ MATCHER_P(MatchesNodeStats, expected, "") { actual.error_count == expected.error_count; } +MATCHER(NodeStatsIsBool, "") { + const BranchCoverage::NodeCoverageStats& actual = arg; + + *result_listener << "\n"; + *result_listener << "Expected: " << FormatNodeStats({true, 0, 0, 0, 0}); + *result_listener << "\n"; + *result_listener << "Got: " << FormatNodeStats(actual); + + return actual.is_boolean == true; +} + TEST(BranchCoverage, DefaultsForUntrackedId) { auto coverage = CreateBranchCoverage(TestExpression()); @@ -204,13 +215,8 @@ TEST(BranchCoverage, IncrementsCounters) { ASSERT_NE(ternary, nullptr); auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id()); // Ternary gets optimized to conditional jumps, so it isn't instrumented - // directly. - EXPECT_THAT(ternary_node_stats, - MatchesNodeStats(Stats{/*is_boolean=*/true, - /*evaluation_count=*/0, - /*boolean_true_count=*/0, - /*boolean_false_count=*/0, - /*error_count=*/0})); + // directly in stack machine impl. + EXPECT_THAT(ternary_node_stats, NodeStatsIsBool()); const auto* false_node = ternary->children().at(2); auto false_node_stats = coverage->StatsForNode(false_node->expr()->id()); @@ -328,13 +334,9 @@ TEST(BranchCoverage, AccumulatesAcrossRuns) { ASSERT_NE(ternary, nullptr); auto ternary_node_stats = coverage->StatsForNode(ternary->expr()->id()); - // Ternary gets optimized into conditional jumps. - EXPECT_THAT(ternary_node_stats, - MatchesNodeStats(Stats{/*is_boolean=*/true, - /*evaluation_count=*/0, - /*boolean_true_count=*/0, - /*boolean_false_count=*/0, - /*error_count=*/0})); + + // Ternary gets optimized into conditional jumps for stack machine plan. + EXPECT_THAT(ternary_node_stats, NodeStatsIsBool()); const auto* false_node = ternary->children().at(2); auto false_node_stats = coverage->StatsForNode(false_node->expr()->id());