Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_library(
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/functional:overload",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -992,6 +993,7 @@ cc_library(
srcs = ["ast_proto.cc"],
hdrs = ["ast_proto.h"],
deps = [
":ast",
":constant",
":expr",
"//base:ast",
Expand Down Expand Up @@ -1022,18 +1024,26 @@ cc_test(
deps = [
":ast",
":ast_proto",
":decl",
":expr",
":source",
":type",
"//common/ast:ast_impl",
"//common/ast:expr",
"//compiler",
"//compiler:compiler_factory",
"//compiler:optional",
"//compiler:standard_library",
"//extensions:comprehensions_v2",
"//internal:proto_matchers",
"//internal:status_macros",
"//internal:testing",
"//parser",
"//parser:options",
"//internal:testing_descriptor_pool",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:variant",
"@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
Expand Down
5 changes: 2 additions & 3 deletions common/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ class Ast {
expr_version_(std::move(expr_version)),
is_checked_(true) {}

// Move-only
Ast(const Ast& other) = delete;
Ast& operator=(const Ast& other) = delete;
Ast(const Ast& other) = default;
Ast& operator=(const Ast& other) = default;
Ast(Ast&& other) = default;
Ast& operator=(Ast&& other) = default;

Expand Down
10 changes: 10 additions & 0 deletions common/ast/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ class SourceInfo {
macro_calls_(std::move(macro_calls)),
extensions_(std::move(extensions)) {}

SourceInfo(const SourceInfo& other) = default;
SourceInfo(SourceInfo&& other) = default;
SourceInfo& operator=(const SourceInfo& other) = default;
SourceInfo& operator=(SourceInfo&& other) = default;

void set_syntax_version(std::string syntax_version) {
syntax_version_ = std::move(syntax_version);
}
Expand Down Expand Up @@ -787,6 +792,11 @@ class Reference {
overload_id_(std::move(overload_id)),
value_(std::move(value)) {}

Reference(const Reference& other) = default;
Reference& operator=(const Reference& other) = default;
Reference(Reference&&) = default;
Reference& operator=(Reference&&) = default;

void set_name(std::string name) { name_ = std::move(name); }

void set_overload_id(std::vector<std::string> overload_id) {
Expand Down
24 changes: 10 additions & 14 deletions common/ast_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/variant.h"
#include "base/ast.h"
#include "common/ast.h"
#include "common/ast/ast_impl.h"
#include "common/ast/constant_proto.h"
#include "common/ast/expr.h"
Expand Down Expand Up @@ -499,12 +499,10 @@ absl::StatusOr<std::unique_ptr<Ast>> CreateAstFromParsedExpr(

absl::Status AstToParsedExpr(const Ast& ast,
cel::expr::ParsedExpr* absl_nonnull out) {
const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(ast);
ParsedExprPb& parsed_expr = *out;
CEL_RETURN_IF_ERROR(
ExprToProto(ast_impl.root_expr(), parsed_expr.mutable_expr()));
CEL_RETURN_IF_ERROR(ExprToProto(ast.root_expr(), parsed_expr.mutable_expr()));
CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto(
ast_impl.source_info(), parsed_expr.mutable_source_info()));
ast.source_info(), parsed_expr.mutable_source_info()));

return absl::OkStatus();
}
Expand Down Expand Up @@ -539,25 +537,23 @@ absl::StatusOr<std::unique_ptr<Ast>> CreateAstFromCheckedExpr(

absl::Status AstToCheckedExpr(
const Ast& ast, cel::expr::CheckedExpr* absl_nonnull out) {
if (!ast.IsChecked()) {
if (!ast.is_checked()) {
return absl::InvalidArgumentError("AST is not type-checked");
}
const auto& ast_impl = ast_internal::AstImpl::CastFromPublicAst(ast);
CheckedExprPb& checked_expr = *out;
checked_expr.set_expr_version(ast_impl.expr_version());
checked_expr.set_expr_version(ast.expr_version());
CEL_RETURN_IF_ERROR(
ExprToProto(ast_impl.root_expr(), checked_expr.mutable_expr()));
ExprToProto(ast.root_expr(), checked_expr.mutable_expr()));
CEL_RETURN_IF_ERROR(ast_internal::SourceInfoToProto(
ast_impl.source_info(), checked_expr.mutable_source_info()));
for (auto it = ast_impl.reference_map().begin();
it != ast_impl.reference_map().end(); ++it) {
ast.source_info(), checked_expr.mutable_source_info()));
for (auto it = ast.reference_map().begin(); it != ast.reference_map().end();
++it) {
ReferencePb& dest_reference =
(*checked_expr.mutable_reference_map())[it->first];
CEL_ASSIGN_OR_RETURN(dest_reference, ReferenceToProto(it->second));
}

for (auto it = ast_impl.type_map().begin(); it != ast_impl.type_map().end();
++it) {
for (auto it = ast.type_map().begin(); it != ast.type_map().end(); ++it) {
TypePb& dest_type = (*checked_expr.mutable_type_map())[it->first];
CEL_RETURN_IF_ERROR(TypeToProto(it->second, &dest_type));
}
Expand Down
116 changes: 88 additions & 28 deletions common/ast_proto_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,25 @@
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/variant.h"
#include "common/ast.h"
#include "common/ast/ast_impl.h"
#include "common/ast/expr.h"
#include "common/decl.h"
#include "common/expr.h"
#include "common/source.h"
#include "common/type.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "compiler/optional.h"
#include "compiler/standard_library.h"
#include "extensions/comprehensions_v2.h"
#include "internal/proto_matchers.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
#include "parser/options.h"
#include "parser/parser.h"
#include "internal/testing_descriptor_pool.h"
#include "google/protobuf/text_format.h"

namespace cel {
Expand All @@ -51,7 +59,6 @@ using ::cel::ast_internal::WellKnownType;
using ::cel::internal::test::EqualsProto;
using ::cel::expr::CheckedExpr;
using ::cel::expr::ParsedExpr;
using ::google::api::expr::parser::Parse;
using ::testing::HasSubstr;

using TypePb = cel::expr::Type;
Expand Down Expand Up @@ -804,17 +811,50 @@ class ConversionRoundTripTest
: public testing::TestWithParam<ConversionRoundTripCase> {
public:
ConversionRoundTripTest() {
options_.add_macro_calls = true;
options_.enable_optional_syntax = true;
auto builder =
cel::NewCompilerBuilder(internal::GetTestingDescriptorPool()).value();
builder->AddLibrary(cel::StandardCompilerLibrary()).IgnoreError();
builder->AddLibrary(OptionalCompilerLibrary()).IgnoreError();
builder->AddLibrary(extensions::ComprehensionsV2CompilerLibrary())
.IgnoreError();
builder->GetCheckerBuilder().set_container("cel.expr.conformance.proto3");
builder->GetCheckerBuilder()
.AddVariable(MakeVariableDecl("ident", IntType()))
.IgnoreError();
builder->GetCheckerBuilder()
.AddVariable(MakeVariableDecl("map_ident", JsonMapType()))
.IgnoreError();
compiler_ = builder->Build().value();
}

absl::StatusOr<ParsedExpr> ParseToProto(absl::string_view expr) {
CEL_ASSIGN_OR_RETURN(auto source, cel::NewSource(expr));

CEL_ASSIGN_OR_RETURN(auto result, compiler_->GetParser().Parse(*source));
ParsedExpr parsed_expr;

CEL_RETURN_IF_ERROR(AstToParsedExpr(*result, &parsed_expr));
return parsed_expr;
}

absl::StatusOr<CheckedExpr> CompileToProto(absl::string_view expr) {
CEL_ASSIGN_OR_RETURN(auto result, compiler_->Compile(expr));
if (!result.IsValid()) {
return absl::InvalidArgumentError(absl::StrCat(
"Compilation failed: '", expr, "': ", result.FormatError()));
}
CEL_ASSIGN_OR_RETURN(auto ast, result.ReleaseAst());
CheckedExpr checked_expr;
CEL_RETURN_IF_ERROR(AstToCheckedExpr(*ast, &checked_expr));
return checked_expr;
}

protected:
ParserOptions options_;
std::unique_ptr<Compiler> compiler_;
};

TEST_P(ConversionRoundTripTest, ParsedExprCopyable) {
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
Parse(GetParam().expr, "<input>", options_));
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromParsedExpr(parsed_expr));
Expand All @@ -825,31 +865,52 @@ TEST_P(ConversionRoundTripTest, ParsedExprCopyable) {
EXPECT_THAT(AstToCheckedExpr(impl, &expr_pb),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("AST is not type-checked")));
ParsedExpr copy;
ASSERT_THAT(AstToParsedExpr(impl, &copy), IsOk());
EXPECT_THAT(copy, EqualsProto(parsed_expr));
ParsedExpr proto_out;
ASSERT_THAT(AstToParsedExpr(impl, &proto_out), IsOk());
EXPECT_THAT(proto_out, EqualsProto(parsed_expr));
}

TEST_P(ConversionRoundTripTest, CheckedExprCopyable) {
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
Parse(GetParam().expr, "<input>", options_));
TEST_P(ConversionRoundTripTest, ExprCopyable) {
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseToProto(GetParam().expr));

CheckedExpr checked_expr;
*checked_expr.mutable_expr() = parsed_expr.expr();
*checked_expr.mutable_source_info() = parsed_expr.source_info();
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromParsedExpr(parsed_expr));

Expr copy = ast->root_expr();
ast->mutable_root_expr() = std::move(copy);

ParsedExpr parsed_pb_out;
CheckedExpr checked_pb_out;
EXPECT_THAT(AstToCheckedExpr(*ast, &checked_pb_out),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("AST is not type-checked")));
ASSERT_THAT(AstToParsedExpr(*ast, &parsed_pb_out), IsOk());
EXPECT_THAT(parsed_pb_out, EqualsProto(parsed_expr));
}

int64_t root_id = checked_expr.expr().id();
(*checked_expr.mutable_reference_map())[root_id].add_overload_id("_==_");
(*checked_expr.mutable_type_map())[root_id].set_primitive(TypePb::BOOL);
TEST_P(ConversionRoundTripTest, CheckedExprRoundTrip) {
ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr,
CompileToProto(GetParam().expr));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromCheckedExpr(checked_expr));

const auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast);
CheckedExpr checked_pb_out;
ASSERT_THAT(AstToCheckedExpr(*ast, &checked_pb_out), IsOk());
EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr));
}

CheckedExpr expr_pb;
ASSERT_THAT(AstToCheckedExpr(impl, &expr_pb), IsOk());
EXPECT_THAT(expr_pb, EqualsProto(checked_expr));
TEST_P(ConversionRoundTripTest, CheckedExprCopyRoundTrip) {
ASSERT_OK_AND_ASSIGN(CheckedExpr checked_expr,
CompileToProto(GetParam().expr));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromCheckedExpr(checked_expr));

Ast copy = *ast;
CheckedExpr checked_pb_out;
ASSERT_THAT(AstToCheckedExpr(copy, &checked_pb_out), IsOk());
EXPECT_THAT(checked_pb_out, EqualsProto(checked_expr));
}

INSTANTIATE_TEST_SUITE_P(
Expand All @@ -863,11 +924,12 @@ INSTANTIATE_TEST_SUITE_P(
{R"cel("42" == "42")cel"},
{R"cel("s".startsWith("s") == true)cel"},
{R"cel([1, 2, 3] == [1, 2, 3])cel"},
{R"cel([1, 2, 3].all(i, e, i == e - 1) == true)cel"},
{R"cel(TestAllTypes{single_int64: 42}.single_int64 == 42)cel"},
{R"cel([1, 2, 3].map(x, x + 2).size() == 3)cel"},
{R"cel({"a": 1, "b": 2}["a"] == 1)cel"},
{R"cel(ident == 42)cel"},
{R"cel(ident.field == 42)cel"},
{R"cel(map_ident.field == 42)cel"},
{R"cel({?"abc": {}[?1]}.?abc.orValue(42) == 42)cel"},
{R"cel([1, 2, ?optional.none()].size() == 2)cel"}}));

Expand Down Expand Up @@ -895,10 +957,8 @@ TEST(ExtensionConversionRoundTripTest, RoundTrip) {
ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> ast,
CreateAstFromParsedExpr(parsed_expr));

const auto& impl = ast_internal::AstImpl::CastFromPublicAst(*ast);

CheckedExpr expr_pb;
EXPECT_THAT(AstToCheckedExpr(impl, &expr_pb),
EXPECT_THAT(AstToCheckedExpr(*ast, &expr_pb),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("AST is not type-checked")));
ParsedExpr copy;
Expand Down
Loading