From 0b61b18ff5242003e173da18671a1574ed88b898 Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Fri, 27 Jun 2025 19:08:42 +0200 Subject: [PATCH] feat(core): add context to RelVisitor, ExpressionVisitor and FuncArgVisitor BREAKING CHANGE: adds a new context argument to RelVisitor, ExpressionVisitor and FuncArgVisitor Signed-off-by: Niels Pardon --- .../expression/AbstractExpressionVisitor.java | 168 ++++++------ .../java/io/substrait/expression/EnumArg.java | 8 +- .../io/substrait/expression/Expression.java | 249 ++++++++++++------ .../expression/ExpressionVisitor.java | 86 +++--- .../substrait/expression/FieldReference.java | 7 +- .../io/substrait/expression/FunctionArg.java | 31 ++- .../proto/ExpressionProtoConverter.java | 138 ++++++---- .../ExtendedExpressionProtoConverter.java | 2 +- .../relation/AbstractRelVisitor.java | 111 ++++---- .../java/io/substrait/relation/Aggregate.java | 6 +- .../AggregateFunctionProtoConverter.java | 2 +- .../relation/ConsistentPartitionWindow.java | 6 +- .../substrait/relation/CopyOnWriteUtils.java | 16 +- .../java/io/substrait/relation/Cross.java | 6 +- .../java/io/substrait/relation/EmptyScan.java | 6 +- .../java/io/substrait/relation/Expand.java | 6 +- .../ExpressionCopyOnWriteVisitor.java | 215 +++++++++------ .../io/substrait/relation/ExtensionDdl.java | 6 +- .../io/substrait/relation/ExtensionLeaf.java | 6 +- .../io/substrait/relation/ExtensionMulti.java | 6 +- .../substrait/relation/ExtensionSingle.java | 6 +- .../io/substrait/relation/ExtensionTable.java | 6 +- .../io/substrait/relation/ExtensionWrite.java | 6 +- .../java/io/substrait/relation/Fetch.java | 6 +- .../java/io/substrait/relation/Filter.java | 6 +- .../main/java/io/substrait/relation/Join.java | 6 +- .../io/substrait/relation/LocalFiles.java | 6 +- .../java/io/substrait/relation/NamedDdl.java | 6 +- .../java/io/substrait/relation/NamedScan.java | 6 +- .../io/substrait/relation/NamedUpdate.java | 6 +- .../io/substrait/relation/NamedWrite.java | 6 +- .../java/io/substrait/relation/Project.java | 6 +- .../main/java/io/substrait/relation/Rel.java | 4 +- .../relation/RelCopyOnWriteVisitor.java | 216 ++++++++------- .../substrait/relation/RelProtoConverter.java | 71 ++--- .../io/substrait/relation/RelVisitor.java | 55 ++-- .../main/java/io/substrait/relation/Set.java | 6 +- .../main/java/io/substrait/relation/Sort.java | 6 +- .../substrait/relation/VirtualTableScan.java | 6 +- .../substrait/relation/physical/HashJoin.java | 6 +- .../relation/physical/MergeJoin.java | 6 +- .../relation/physical/NestedLoopJoin.java | 6 +- .../src/main/java/io/substrait/type/Type.java | 8 +- .../util/EmptyVisitationContext.java | 3 + .../io/substrait/util/VisitationContext.java | 3 + .../type/proto/GenericRoundtripTest.java | 2 +- .../type/proto/IfThenRoundtripTest.java | 4 +- .../type/proto/LiteralRoundtripTest.java | 2 +- .../io/substrait/isthmus/SchemaCollector.java | 7 +- .../isthmus/SubstraitRelNodeConverter.java | 96 ++++--- .../substrait/isthmus/SubstraitToCalcite.java | 9 +- .../isthmus/expression/CallConverters.java | 4 +- .../expression/ExpressionRexConverter.java | 176 ++++++++----- .../io/substrait/isthmus/CalciteCallTest.java | 2 +- .../substrait/isthmus/CalciteLiteralTest.java | 11 +- .../isthmus/ExpressionConvertabilityTest.java | 13 +- .../isthmus/FunctionConversionTest.java | 28 +- .../isthmus/ProtoPlanConverterTest.java | 10 +- .../isthmus/RelCopyOnWriteVisitorTest.java | 34 +-- .../isthmus/RelExtensionRoundtripTest.java | 25 +- .../SubstraitExpressionConverterTest.java | 6 +- .../substrait/debug/ExpressionToString.scala | 33 ++- .../substrait/debug/RelToVerboseString.scala | 35 +-- .../spark/DefaultExpressionVisitor.scala | 35 ++- .../substrait/spark/DefaultRelVisitor.scala | 5 +- .../spark/expression/ToSparkExpression.scala | 120 +++++---- .../spark/logical/ToLogicalPlan.scala | 87 +++--- .../spark/logical/ToSubstraitRel.scala | 2 +- .../spark/SubstraitPlanTestBase.scala | 2 +- .../spark/TypesAndLiteralsSuite.scala | 4 +- .../SubstraitExpressionTestBase.scala | 2 +- 71 files changed, 1364 insertions(+), 939 deletions(-) create mode 100644 core/src/main/java/io/substrait/util/EmptyVisitationContext.java create mode 100644 core/src/main/java/io/substrait/util/VisitationContext.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 5f30a1931..072507295 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -1,206 +1,208 @@ package io.substrait.expression; -public abstract class AbstractExpressionVisitor - implements ExpressionVisitor { - public abstract OUTPUT visitFallback(Expression expr); +import io.substrait.util.VisitationContext; + +public abstract class AbstractExpressionVisitor + implements ExpressionVisitor { + public abstract O visitFallback(Expression expr, C context); @Override - public OUTPUT visit(Expression.NullLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.NullLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.BoolLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.BoolLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.I8Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.I8Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.I16Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.I16Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.I32Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.I32Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.I64Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.I64Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.FP32Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.FP32Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.FP64Literal expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.FP64Literal expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.StrLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.StrLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.BinaryLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.BinaryLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.TimeLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.TimeLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.DateLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.DateLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.TimestampLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.TimestampLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.TimestampTZLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.TimestampTZLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.PrecisionTimestampLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.PrecisionTimestampLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.PrecisionTimestampTZLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.PrecisionTimestampTZLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.IntervalYearLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.IntervalYearLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.IntervalDayLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.IntervalDayLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.IntervalCompoundLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.IntervalCompoundLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.UUIDLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.UUIDLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.FixedCharLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.FixedCharLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.VarCharLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.VarCharLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.FixedBinaryLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.FixedBinaryLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.DecimalLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.DecimalLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.MapLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.MapLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.EmptyMapLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.EmptyMapLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.ListLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.ListLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.EmptyListLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.EmptyListLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.StructLiteral expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.Switch expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.Switch expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.IfThen expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.IfThen expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.ScalarFunctionInvocation expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.ScalarFunctionInvocation expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.WindowFunctionInvocation expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.WindowFunctionInvocation expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.Cast expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.Cast expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.SingleOrList expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.SingleOrList expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.MultiOrList expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.MultiOrList expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(FieldReference expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(FieldReference expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.SetPredicate expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.SetPredicate expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.ScalarSubquery expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.ScalarSubquery expr, C context) throws E { + return visitFallback(expr, context); } @Override - public OUTPUT visit(Expression.InPredicate expr) throws EXCEPTION { - return visitFallback(expr); + public O visit(Expression.InPredicate expr, C context) throws E { + return visitFallback(expr, context); } } diff --git a/core/src/main/java/io/substrait/expression/EnumArg.java b/core/src/main/java/io/substrait/expression/EnumArg.java index e006016b6..cca316d75 100644 --- a/core/src/main/java/io/substrait/expression/EnumArg.java +++ b/core/src/main/java/io/substrait/expression/EnumArg.java @@ -1,6 +1,7 @@ package io.substrait.expression; import io.substrait.extension.SimpleExtension; +import io.substrait.util.VisitationContext; import java.util.Optional; import org.immutables.value.Value; @@ -16,9 +17,10 @@ public interface EnumArg extends FunctionArg { Optional value(); @Override - default R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) throws E { - return fnArgVisitor.visitEnumArg(fnDef, argIdx, this); + default R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + throws E { + return fnArgVisitor.visitEnumArg(fnDef, argIdx, this, context); } static EnumArg of(SimpleExtension.EnumArgument enumArg, String option) { diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index aa9e69148..75e003f53 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -6,6 +6,7 @@ import io.substrait.relation.Rel; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -18,9 +19,10 @@ public interface Expression extends FunctionArg { Type getType(); @Override - default R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) throws E { - return fnArgVisitor.visitExpr(fnDef, argIdx, this); + default R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + throws E { + return fnArgVisitor.visitExpr(fnDef, argIdx, this, context); } interface Literal extends Expression { @@ -30,7 +32,8 @@ default boolean nullable() { } } - R accept(ExpressionVisitor visitor) throws E; + R accept( + ExpressionVisitor visitor, C context) throws E; @Value.Immutable abstract static class NullLiteral implements Literal { @@ -44,8 +47,10 @@ public static ImmutableExpression.NullLiteral.Builder builder() { return ImmutableExpression.NullLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -61,8 +66,10 @@ public static ImmutableExpression.BoolLiteral.Builder builder() { return ImmutableExpression.BoolLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -78,8 +85,10 @@ public static ImmutableExpression.I8Literal.Builder builder() { return ImmutableExpression.I8Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -95,8 +104,10 @@ public static ImmutableExpression.I16Literal.Builder builder() { return ImmutableExpression.I16Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -112,8 +123,10 @@ public static ImmutableExpression.I32Literal.Builder builder() { return ImmutableExpression.I32Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -129,8 +142,10 @@ public static ImmutableExpression.I64Literal.Builder builder() { return ImmutableExpression.I64Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -146,8 +161,10 @@ public static ImmutableExpression.FP32Literal.Builder builder() { return ImmutableExpression.FP32Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -163,8 +180,10 @@ public static ImmutableExpression.FP64Literal.Builder builder() { return ImmutableExpression.FP64Literal.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -180,8 +199,10 @@ public static ImmutableExpression.StrLiteral.Builder builder() { return ImmutableExpression.StrLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -197,8 +218,10 @@ public static ImmutableExpression.BinaryLiteral.Builder builder() { return ImmutableExpression.BinaryLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -214,8 +237,10 @@ public static ImmutableExpression.TimestampLiteral.Builder builder() { return ImmutableExpression.TimestampLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -231,8 +256,10 @@ public static ImmutableExpression.TimeLiteral.Builder builder() { return ImmutableExpression.TimeLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -248,8 +275,9 @@ public static ImmutableExpression.DateLiteral.Builder builder() { return ImmutableExpression.DateLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -265,8 +293,10 @@ public static ImmutableExpression.TimestampTZLiteral.Builder builder() { return ImmutableExpression.TimestampTZLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -284,8 +314,10 @@ public static ImmutableExpression.PrecisionTimestampLiteral.Builder builder() { return ImmutableExpression.PrecisionTimestampLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -303,8 +335,10 @@ public static ImmutableExpression.PrecisionTimestampTZLiteral.Builder builder() return ImmutableExpression.PrecisionTimestampTZLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -322,8 +356,10 @@ public static ImmutableExpression.IntervalYearLiteral.Builder builder() { return ImmutableExpression.IntervalYearLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -345,8 +381,10 @@ public static ImmutableExpression.IntervalDayLiteral.Builder builder() { return ImmutableExpression.IntervalDayLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -374,8 +412,10 @@ public static ImmutableExpression.IntervalCompoundLiteral.Builder builder() { return ImmutableExpression.IntervalCompoundLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -391,8 +431,10 @@ public static ImmutableExpression.UUIDLiteral.Builder builder() { return ImmutableExpression.UUIDLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public ByteString toBytes() { @@ -416,8 +458,10 @@ public static ImmutableExpression.FixedCharLiteral.Builder builder() { return ImmutableExpression.FixedCharLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -435,8 +479,10 @@ public static ImmutableExpression.VarCharLiteral.Builder builder() { return ImmutableExpression.VarCharLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -452,8 +498,10 @@ public static ImmutableExpression.FixedBinaryLiteral.Builder builder() { return ImmutableExpression.FixedBinaryLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -473,8 +521,10 @@ public static ImmutableExpression.DecimalLiteral.Builder builder() { return ImmutableExpression.DecimalLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -493,8 +543,10 @@ public static ImmutableExpression.MapLiteral.Builder builder() { return ImmutableExpression.MapLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -512,8 +564,10 @@ public static ImmutableExpression.EmptyMapLiteral.Builder builder() { return ImmutableExpression.EmptyMapLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -529,8 +583,10 @@ public static ImmutableExpression.ListLiteral.Builder builder() { return ImmutableExpression.ListLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -548,8 +604,9 @@ public static ImmutableExpression.EmptyListLiteral.Builder builder() { } @Override - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -569,8 +626,10 @@ public static ImmutableExpression.StructLiteral.Builder builder() { return ImmutableExpression.StructLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -590,8 +649,10 @@ public static ImmutableExpression.UserDefinedLiteral.Builder builder() { return ImmutableExpression.UserDefinedLiteral.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -611,8 +672,10 @@ public static ImmutableExpression.Switch.Builder builder() { return ImmutableExpression.Switch.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -647,8 +710,10 @@ public static ImmutableExpression.IfThen.Builder builder() { return ImmutableExpression.IfThen.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -679,8 +744,10 @@ public static ImmutableExpression.Cast.Builder builder() { return ImmutableExpression.Cast.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -702,8 +769,10 @@ public static ImmutableExpression.ScalarFunctionInvocation.Builder builder() { return ImmutableExpression.ScalarFunctionInvocation.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -740,8 +809,10 @@ public static ImmutableExpression.WindowFunctionInvocation.Builder builder() { return ImmutableExpression.WindowFunctionInvocation.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -786,8 +857,10 @@ public static ImmutableExpression.SingleOrList.Builder builder() { return ImmutableExpression.SingleOrList.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -805,8 +878,10 @@ public static ImmutableExpression.MultiOrList.Builder builder() { return ImmutableExpression.MultiOrList.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -846,8 +921,10 @@ public static ImmutableExpression.SetPredicate.Builder builder() { return ImmutableExpression.SetPredicate.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -859,8 +936,10 @@ public static ImmutableExpression.ScalarSubquery.Builder builder() { return ImmutableExpression.ScalarSubquery.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } @@ -878,8 +957,10 @@ public static ImmutableExpression.InPredicate.Builder builder() { return ImmutableExpression.InPredicate.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } } diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index b27a241a2..d478a4b29 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -1,87 +1,89 @@ package io.substrait.expression; -public interface ExpressionVisitor { +import io.substrait.util.VisitationContext; + +public interface ExpressionVisitor { static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ExpressionVisitor.class); - R visit(Expression.NullLiteral expr) throws E; + R visit(Expression.NullLiteral expr, C context) throws E; - R visit(Expression.BoolLiteral expr) throws E; + R visit(Expression.BoolLiteral expr, C context) throws E; - R visit(Expression.I8Literal expr) throws E; + R visit(Expression.I8Literal expr, C context) throws E; - R visit(Expression.I16Literal expr) throws E; + R visit(Expression.I16Literal expr, C context) throws E; - R visit(Expression.I32Literal expr) throws E; + R visit(Expression.I32Literal expr, C context) throws E; - R visit(Expression.I64Literal expr) throws E; + R visit(Expression.I64Literal expr, C context) throws E; - R visit(Expression.FP32Literal expr) throws E; + R visit(Expression.FP32Literal expr, C context) throws E; - R visit(Expression.FP64Literal expr) throws E; + R visit(Expression.FP64Literal expr, C context) throws E; - R visit(Expression.StrLiteral expr) throws E; + R visit(Expression.StrLiteral expr, C context) throws E; - R visit(Expression.BinaryLiteral expr) throws E; + R visit(Expression.BinaryLiteral expr, C context) throws E; - R visit(Expression.TimeLiteral expr) throws E; + R visit(Expression.TimeLiteral expr, C context) throws E; - R visit(Expression.DateLiteral expr) throws E; + R visit(Expression.DateLiteral expr, C context) throws E; - R visit(Expression.TimestampLiteral expr) throws E; + R visit(Expression.TimestampLiteral expr, C context) throws E; - R visit(Expression.TimestampTZLiteral expr) throws E; + R visit(Expression.TimestampTZLiteral expr, C context) throws E; - R visit(Expression.PrecisionTimestampLiteral expr) throws E; + R visit(Expression.PrecisionTimestampLiteral expr, C context) throws E; - R visit(Expression.PrecisionTimestampTZLiteral expr) throws E; + R visit(Expression.PrecisionTimestampTZLiteral expr, C context) throws E; - R visit(Expression.IntervalYearLiteral expr) throws E; + R visit(Expression.IntervalYearLiteral expr, C context) throws E; - R visit(Expression.IntervalDayLiteral expr) throws E; + R visit(Expression.IntervalDayLiteral expr, C context) throws E; - R visit(Expression.IntervalCompoundLiteral expr) throws E; + R visit(Expression.IntervalCompoundLiteral expr, C context) throws E; - R visit(Expression.UUIDLiteral expr) throws E; + R visit(Expression.UUIDLiteral expr, C context) throws E; - R visit(Expression.FixedCharLiteral expr) throws E; + R visit(Expression.FixedCharLiteral expr, C context) throws E; - R visit(Expression.VarCharLiteral expr) throws E; + R visit(Expression.VarCharLiteral expr, C context) throws E; - R visit(Expression.FixedBinaryLiteral expr) throws E; + R visit(Expression.FixedBinaryLiteral expr, C context) throws E; - R visit(Expression.DecimalLiteral expr) throws E; + R visit(Expression.DecimalLiteral expr, C context) throws E; - R visit(Expression.MapLiteral expr) throws E; + R visit(Expression.MapLiteral expr, C context) throws E; - R visit(Expression.EmptyMapLiteral expr) throws E; + R visit(Expression.EmptyMapLiteral expr, C context) throws E; - R visit(Expression.ListLiteral expr) throws E; + R visit(Expression.ListLiteral expr, C context) throws E; - R visit(Expression.EmptyListLiteral expr) throws E; + R visit(Expression.EmptyListLiteral expr, C context) throws E; - R visit(Expression.StructLiteral expr) throws E; + R visit(Expression.StructLiteral expr, C context) throws E; - R visit(Expression.UserDefinedLiteral expr) throws E; + R visit(Expression.UserDefinedLiteral expr, C context) throws E; - R visit(Expression.Switch expr) throws E; + R visit(Expression.Switch expr, C context) throws E; - R visit(Expression.IfThen expr) throws E; + R visit(Expression.IfThen expr, C context) throws E; - R visit(Expression.ScalarFunctionInvocation expr) throws E; + R visit(Expression.ScalarFunctionInvocation expr, C context) throws E; - R visit(Expression.WindowFunctionInvocation expr) throws E; + R visit(Expression.WindowFunctionInvocation expr, C context) throws E; - R visit(Expression.Cast expr) throws E; + R visit(Expression.Cast expr, C context) throws E; - R visit(Expression.SingleOrList expr) throws E; + R visit(Expression.SingleOrList expr, C context) throws E; - R visit(Expression.MultiOrList expr) throws E; + R visit(Expression.MultiOrList expr, C context) throws E; - R visit(FieldReference expr) throws E; + R visit(FieldReference expr, C context) throws E; - R visit(Expression.SetPredicate expr) throws E; + R visit(Expression.SetPredicate expr, C context) throws E; - R visit(Expression.ScalarSubquery expr) throws E; + R visit(Expression.ScalarSubquery expr, C context) throws E; - R visit(Expression.InPredicate expr) throws E; + R visit(Expression.InPredicate expr, C context) throws E; } diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 725179c7e..f2926f473 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -3,6 +3,7 @@ import io.substrait.relation.Rel; import io.substrait.type.Type; import io.substrait.type.TypeVisitor; +import io.substrait.util.VisitationContext; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -28,8 +29,10 @@ public static ImmutableFieldReference.Builder builder() { return ImmutableFieldReference.builder(); } - public R accept(ExpressionVisitor visitor) throws E { - return visitor.visit(this); + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public boolean isSimpleRootReference() { diff --git a/core/src/main/java/io/substrait/expression/FunctionArg.java b/core/src/main/java/io/substrait/expression/FunctionArg.java index 409a05d7b..495def8ad 100644 --- a/core/src/main/java/io/substrait/expression/FunctionArg.java +++ b/core/src/main/java/io/substrait/expression/FunctionArg.java @@ -6,6 +6,8 @@ import io.substrait.proto.FunctionArgument; import io.substrait.type.Type; import io.substrait.type.proto.ProtoTypeConverter; +import io.substrait.util.EmptyVisitationContext; +import io.substrait.util.VisitationContext; /** * FunctionArg is a marker interface that represents an argument of a {@link @@ -15,39 +17,44 @@ */ public interface FunctionArg { - R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) throws E; + R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + throws E; - interface FuncArgVisitor { - R visitExpr(SimpleExtension.Function fnDef, int argIdx, Expression e) throws E; + interface FuncArgVisitor { + R visitExpr(SimpleExtension.Function fnDef, int argIdx, Expression e, C context) throws E; - R visitType(SimpleExtension.Function fnDef, int argIdx, Type t) throws E; + R visitType(SimpleExtension.Function fnDef, int argIdx, Type t, C context) throws E; - R visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg e) throws E; + R visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg e, C context) throws E; } - static FuncArgVisitor toProto( + static FuncArgVisitor toProto( TypeExpressionVisitor typeVisitor, - ExpressionVisitor expressionVisitor) { + ExpressionVisitor + expressionVisitor) { return new FuncArgVisitor<>() { @Override - public FunctionArgument visitExpr(SimpleExtension.Function fnDef, int argIdx, Expression e) + public FunctionArgument visitExpr( + SimpleExtension.Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) throws RuntimeException { - var pE = e.accept(expressionVisitor); + var pE = e.accept(expressionVisitor, context); return FunctionArgument.newBuilder().setValue(pE).build(); } @Override - public FunctionArgument visitType(SimpleExtension.Function fnDef, int argIdx, Type t) + public FunctionArgument visitType( + SimpleExtension.Function fnDef, int argIdx, Type t, EmptyVisitationContext context) throws RuntimeException { var pTyp = t.accept(typeVisitor); return FunctionArgument.newBuilder().setType(pTyp).build(); } @Override - public FunctionArgument visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg ea) + public FunctionArgument visitEnumArg( + SimpleExtension.Function fnDef, int argIdx, EnumArg ea, EmptyVisitationContext context) throws RuntimeException { var enumBldr = FunctionArgument.newBuilder(); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index da86d704e..a51a53de9 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -15,6 +15,7 @@ import io.substrait.proto.Type; import io.substrait.relation.RelProtoConverter; import io.substrait.type.proto.TypeProtoConverter; +import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -22,7 +23,8 @@ /** * Converts from {@link io.substrait.expression.Expression} to {@link io.substrait.proto.Expression} */ -public class ExpressionProtoConverter implements ExpressionVisitor { +public class ExpressionProtoConverter + implements ExpressionVisitor { protected final RelProtoConverter relProtoConverter; protected final TypeProtoConverter typeProtoConverter; @@ -45,7 +47,7 @@ public TypeProtoConverter getTypeProtoConverter() { } public io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { - return expression.accept(this); + return expression.accept(this, null); } public List toProto( @@ -62,7 +64,8 @@ protected io.substrait.proto.Type toProto(io.substrait.type.Type type) { } @Override - public Expression visit(io.substrait.expression.Expression.NullLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.NullLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNull(toProto(expr.type()))); } @@ -73,72 +76,87 @@ private Expression lit(Consumer consumer) { } @Override - public Expression visit(io.substrait.expression.Expression.BoolLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.BoolLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.I8Literal expr) { + public Expression visit( + io.substrait.expression.Expression.I8Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI8(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.I16Literal expr) { + public Expression visit( + io.substrait.expression.Expression.I16Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI16(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.I32Literal expr) { + public Expression visit( + io.substrait.expression.Expression.I32Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI32(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.I64Literal expr) { + public Expression visit( + io.substrait.expression.Expression.I64Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setI64(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.FP32Literal expr) { + public Expression visit( + io.substrait.expression.Expression.FP32Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFp32(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.FP64Literal expr) { + public Expression visit( + io.substrait.expression.Expression.FP64Literal expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFp64(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.StrLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.StrLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setString(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.BinaryLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.BinaryLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setBinary(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.TimeLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.TimeLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTime(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.DateLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.DateLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setDate(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.TimestampLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.TimestampLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTimestamp(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.TimestampTZLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.TimestampTZLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setTimestampTz(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.PrecisionTimestampLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.PrecisionTimestampLiteral expr, + EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -151,7 +169,9 @@ public Expression visit(io.substrait.expression.Expression.PrecisionTimestampLit } @Override - public Expression visit(io.substrait.expression.Expression.PrecisionTimestampTZLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.PrecisionTimestampTZLiteral expr, + EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -164,7 +184,8 @@ public Expression visit(io.substrait.expression.Expression.PrecisionTimestampTZL } @Override - public Expression visit(io.substrait.expression.Expression.IntervalYearLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.IntervalYearLiteral expr, EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -175,7 +196,8 @@ public Expression visit(io.substrait.expression.Expression.IntervalYearLiteral e } @Override - public Expression visit(io.substrait.expression.Expression.IntervalDayLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.IntervalDayLiteral expr, EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -188,7 +210,9 @@ public Expression visit(io.substrait.expression.Expression.IntervalDayLiteral ex } @Override - public Expression visit(io.substrait.expression.Expression.IntervalCompoundLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.IntervalCompoundLiteral expr, + EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -207,17 +231,20 @@ public Expression visit(io.substrait.expression.Expression.IntervalCompoundLiter } @Override - public Expression visit(io.substrait.expression.Expression.UUIDLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.UUIDLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setUuid(expr.toBytes())); } @Override - public Expression visit(io.substrait.expression.Expression.FixedCharLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.FixedCharLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFixedChar(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.VarCharLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.VarCharLiteral expr, EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -228,12 +255,14 @@ public Expression visit(io.substrait.expression.Expression.VarCharLiteral expr) } @Override - public Expression visit(io.substrait.expression.Expression.FixedBinaryLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.FixedBinaryLiteral expr, EmptyVisitationContext context) { return lit(bldr -> bldr.setNullable(expr.nullable()).setFixedBinary(expr.value())); } @Override - public Expression visit(io.substrait.expression.Expression.DecimalLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.DecimalLiteral expr, EmptyVisitationContext context) { return lit( bldr -> bldr.setNullable(expr.nullable()) @@ -245,7 +274,8 @@ public Expression visit(io.substrait.expression.Expression.DecimalLiteral expr) } @Override - public Expression visit(io.substrait.expression.Expression.MapLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.MapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { var keyValues = @@ -266,7 +296,8 @@ public Expression visit(io.substrait.expression.Expression.MapLiteral expr) { } @Override - public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.EmptyMapLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { var protoMapType = toProto(expr.getType()); @@ -281,7 +312,8 @@ public Expression visit(io.substrait.expression.Expression.EmptyMapLiteral expr) } @Override - public Expression visit(io.substrait.expression.Expression.ListLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.ListLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { var values = @@ -294,7 +326,8 @@ public Expression visit(io.substrait.expression.Expression.ListLiteral expr) { } @Override - public Expression visit(io.substrait.expression.Expression.EmptyListLiteral expr) + public Expression visit( + io.substrait.expression.Expression.EmptyListLiteral expr, EmptyVisitationContext context) throws RuntimeException { return lit( builder -> { @@ -311,7 +344,8 @@ public Expression visit(io.substrait.expression.Expression.EmptyListLiteral expr } @Override - public Expression visit(io.substrait.expression.Expression.StructLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.StructLiteral expr, EmptyVisitationContext context) { return lit( bldr -> { var values = @@ -324,7 +358,8 @@ public Expression visit(io.substrait.expression.Expression.StructLiteral expr) { } @Override - public Expression visit(io.substrait.expression.Expression.UserDefinedLiteral expr) { + public Expression visit( + io.substrait.expression.Expression.UserDefinedLiteral expr, EmptyVisitationContext context) { var typeReference = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.uri(), expr.name())); return lit( @@ -349,7 +384,8 @@ private Expression.Literal toLiteral(io.substrait.expression.Expression expressi } @Override - public Expression visit(io.substrait.expression.Expression.Switch expr) { + public Expression visit( + io.substrait.expression.Expression.Switch expr, EmptyVisitationContext context) { var clauses = expr.switchClauses().stream() .map( @@ -369,7 +405,8 @@ public Expression visit(io.substrait.expression.Expression.Switch expr) { } @Override - public Expression visit(io.substrait.expression.Expression.IfThen expr) { + public Expression visit( + io.substrait.expression.Expression.IfThen expr, EmptyVisitationContext context) { var clauses = expr.ifClauses().stream() .map( @@ -386,7 +423,9 @@ public Expression visit(io.substrait.expression.Expression.IfThen expr) { } @Override - public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocation expr) { + public Expression visit( + io.substrait.expression.Expression.ScalarFunctionInvocation expr, + EmptyVisitationContext context) { var argVisitor = FunctionArg.toProto(typeProtoConverter, this); @@ -397,7 +436,7 @@ public Expression visit(io.substrait.expression.Expression.ScalarFunctionInvocat .setFunctionReference(extensionCollector.getFunctionReference(expr.declaration())) .addAllArguments( expr.arguments().stream() - .map(a -> a.accept(expr.declaration(), 0, argVisitor)) + .map(a -> a.accept(expr.declaration(), 0, argVisitor, context)) .collect(java.util.stream.Collectors.toList())) .addAllOptions( expr.options().stream() @@ -414,7 +453,8 @@ public static FunctionOption from(io.substrait.expression.FunctionOption option) } @Override - public Expression visit(io.substrait.expression.Expression.Cast expr) { + public Expression visit( + io.substrait.expression.Expression.Cast expr, EmptyVisitationContext context) { return Expression.newBuilder() .setCast( Expression.Cast.newBuilder() @@ -425,7 +465,8 @@ public Expression visit(io.substrait.expression.Expression.Cast expr) { } @Override - public Expression visit(io.substrait.expression.Expression.SingleOrList expr) + public Expression visit( + io.substrait.expression.Expression.SingleOrList expr, EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSingularOrList( @@ -436,7 +477,8 @@ public Expression visit(io.substrait.expression.Expression.SingleOrList expr) } @Override - public Expression visit(io.substrait.expression.Expression.MultiOrList expr) + public Expression visit( + io.substrait.expression.Expression.MultiOrList expr, EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setMultiOrList( @@ -454,7 +496,7 @@ public Expression visit(io.substrait.expression.Expression.MultiOrList expr) } @Override - public Expression visit(FieldReference expr) { + public Expression visit(FieldReference expr, EmptyVisitationContext context) { Expression.ReferenceSegment seg = null; for (var segment : expr.segments()) { @@ -500,7 +542,8 @@ public Expression visit(FieldReference expr) { } @Override - public Expression visit(io.substrait.expression.Expression.SetPredicate expr) + public Expression visit( + io.substrait.expression.Expression.SetPredicate expr, EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -515,7 +558,8 @@ public Expression visit(io.substrait.expression.Expression.SetPredicate expr) } @Override - public Expression visit(io.substrait.expression.Expression.ScalarSubquery expr) + public Expression visit( + io.substrait.expression.Expression.ScalarSubquery expr, EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -527,7 +571,8 @@ public Expression visit(io.substrait.expression.Expression.ScalarSubquery expr) } @Override - public Expression visit(io.substrait.expression.Expression.InPredicate expr) + public Expression visit( + io.substrait.expression.Expression.InPredicate expr, EmptyVisitationContext context) throws RuntimeException { return Expression.newBuilder() .setSubquery( @@ -541,12 +586,15 @@ public Expression visit(io.substrait.expression.Expression.InPredicate expr) .build(); } - public Expression visit(io.substrait.expression.Expression.WindowFunctionInvocation expr) + @Override + public Expression visit( + io.substrait.expression.Expression.WindowFunctionInvocation expr, + EmptyVisitationContext context) throws RuntimeException { var argVisitor = FunctionArg.toProto(typeProtoConverter, this); List args = expr.arguments().stream() - .map(a -> a.accept(expr.declaration(), 0, argVisitor)) + .map(a -> a.accept(expr.declaration(), 0, argVisitor, context)) .collect(java.util.stream.Collectors.toList()); Type outputType = toProto(expr.getType()); diff --git a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java index 34deae39c..bec48ec7a 100644 --- a/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/extendedexpression/ExtendedExpressionProtoConverter.java @@ -27,7 +27,7 @@ public ExtendedExpression toProto( if (expressionReference instanceof io.substrait.extendedexpression.ExtendedExpression.ExpressionReference et) { io.substrait.proto.Expression expressionProto = - et.getExpression().accept(expressionProtoConverter); + et.getExpression().accept(expressionProtoConverter, null); ExpressionReference.Builder expressionReferenceBuilder = ExpressionReference.newBuilder() .setExpression(expressionProto) diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index d3bdebb1d..99f4e0577 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -3,138 +3,139 @@ import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.util.VisitationContext; -public abstract class AbstractRelVisitor - implements RelVisitor { - public abstract OUTPUT visitFallback(Rel rel); +public abstract class AbstractRelVisitor + implements RelVisitor { + public abstract O visitFallback(Rel rel, C context); @Override - public OUTPUT visit(Aggregate aggregate) throws EXCEPTION { - return visitFallback(aggregate); + public O visit(Aggregate aggregate, C context) throws E { + return visitFallback(aggregate, context); } @Override - public OUTPUT visit(EmptyScan emptyScan) throws EXCEPTION { - return visitFallback(emptyScan); + public O visit(EmptyScan emptyScan, C context) throws E { + return visitFallback(emptyScan, context); } @Override - public OUTPUT visit(Fetch fetch) throws EXCEPTION { - return visitFallback(fetch); + public O visit(Fetch fetch, C context) throws E { + return visitFallback(fetch, context); } @Override - public OUTPUT visit(Filter filter) throws EXCEPTION { - return visitFallback(filter); + public O visit(Filter filter, C context) throws E { + return visitFallback(filter, context); } @Override - public OUTPUT visit(Join join) throws EXCEPTION { - return visitFallback(join); + public O visit(Join join, C context) throws E { + return visitFallback(join, context); } @Override - public OUTPUT visit(Set set) throws EXCEPTION { - return visitFallback(set); + public O visit(Set set, C context) throws E { + return visitFallback(set, context); } @Override - public OUTPUT visit(NamedScan namedScan) throws EXCEPTION { - return visitFallback(namedScan); + public O visit(NamedScan namedScan, C context) throws E { + return visitFallback(namedScan, context); } @Override - public OUTPUT visit(LocalFiles localFiles) throws EXCEPTION { - return visitFallback(localFiles); + public O visit(LocalFiles localFiles, C context) throws E { + return visitFallback(localFiles, context); } @Override - public OUTPUT visit(Project project) throws EXCEPTION { - return visitFallback(project); + public O visit(Project project, C context) throws E { + return visitFallback(project, context); } @Override - public OUTPUT visit(Expand expand) throws EXCEPTION { - return visitFallback(expand); + public O visit(Expand expand, C context) throws E { + return visitFallback(expand, context); } @Override - public OUTPUT visit(Sort sort) throws EXCEPTION { - return visitFallback(sort); + public O visit(Sort sort, C context) throws E { + return visitFallback(sort, context); } @Override - public OUTPUT visit(VirtualTableScan virtualTableScan) throws EXCEPTION { - return visitFallback(virtualTableScan); + public O visit(Cross cross, C context) throws E { + return visitFallback(cross, context); } @Override - public OUTPUT visit(Cross cross) throws EXCEPTION { - return visitFallback(cross); + public O visit(VirtualTableScan virtualTableScan, C context) throws E { + return visitFallback(virtualTableScan, context); } @Override - public OUTPUT visit(ExtensionLeaf extensionLeaf) throws EXCEPTION { - return visitFallback(extensionLeaf); + public O visit(ExtensionLeaf extensionLeaf, C context) throws E { + return visitFallback(extensionLeaf, context); } @Override - public OUTPUT visit(ExtensionSingle extensionSingle) throws EXCEPTION { - return visitFallback(extensionSingle); + public O visit(ExtensionSingle extensionSingle, C context) throws E { + return visitFallback(extensionSingle, context); } @Override - public OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION { - return visitFallback(extensionMulti); + public O visit(ExtensionMulti extensionMulti, C context) throws E { + return visitFallback(extensionMulti, context); } @Override - public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION { - return visitFallback(extensionTable); + public O visit(ExtensionTable extensionTable, C context) throws E { + return visitFallback(extensionTable, context); } @Override - public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION { - return visitFallback(hashJoin); + public O visit(HashJoin hashJoin, C context) throws E { + return visitFallback(hashJoin, context); } @Override - public OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION { - return visitFallback(mergeJoin); + public O visit(MergeJoin mergeJoin, C context) throws E { + return visitFallback(mergeJoin, context); } @Override - public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { - return visitFallback(nestedLoopJoin); + public O visit(NestedLoopJoin nestedLoopJoin, C context) throws E { + return visitFallback(nestedLoopJoin, context); } @Override - public OUTPUT visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION { - return visitFallback(consistentPartitionWindow); + public O visit(ConsistentPartitionWindow consistentPartitionWindow, C context) throws E { + return visitFallback(consistentPartitionWindow, context); } @Override - public OUTPUT visit(NamedWrite write) throws EXCEPTION { - return visitFallback(write); + public O visit(NamedWrite write, C context) throws E { + return visitFallback(write, context); } @Override - public OUTPUT visit(ExtensionWrite write) throws EXCEPTION { - return visitFallback(write); + public O visit(ExtensionWrite write, C context) throws E { + return visitFallback(write, context); } @Override - public OUTPUT visit(NamedDdl ddl) throws EXCEPTION { - return visitFallback(ddl); + public O visit(NamedDdl ddl, C context) throws E { + return visitFallback(ddl, context); } @Override - public OUTPUT visit(ExtensionDdl ddl) throws EXCEPTION { - return visitFallback(ddl); + public O visit(ExtensionDdl ddl, C context) throws E { + return visitFallback(ddl, context); } @Override - public OUTPUT visit(NamedUpdate update) throws EXCEPTION { - return visitFallback(update); + public O visit(NamedUpdate update, C context) throws E { + return visitFallback(update, context); } } diff --git a/core/src/main/java/io/substrait/relation/Aggregate.java b/core/src/main/java/io/substrait/relation/Aggregate.java index d6425efd6..132e73439 100644 --- a/core/src/main/java/io/substrait/relation/Aggregate.java +++ b/core/src/main/java/io/substrait/relation/Aggregate.java @@ -4,6 +4,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.LinkedHashSet; import java.util.List; import java.util.Optional; @@ -35,8 +36,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } @Value.Immutable diff --git a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java index e92752b3f..b53bfc90a 100644 --- a/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/AggregateFunctionProtoConverter.java @@ -34,7 +34,7 @@ public AggregateFunction toProto(Aggregate.Measure measure) { .setOutputType(measure.getFunction().getType().accept(typeProtoConverter)) .addAllArguments( IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) + .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) .collect(java.util.stream.Collectors.toList())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) diff --git a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java index f1f9cbe71..1c736d34c 100644 --- a/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java +++ b/core/src/main/java/io/substrait/relation/ConsistentPartitionWindow.java @@ -8,6 +8,7 @@ import io.substrait.extension.SimpleExtension; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import java.util.stream.Stream; import org.immutables.value.Value; @@ -33,8 +34,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableConsistentPartitionWindow.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java b/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java index bda780032..e470c526b 100644 --- a/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java +++ b/core/src/main/java/io/substrait/relation/CopyOnWriteUtils.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -23,8 +24,8 @@ public static Optional or(Optional left, Supplier { - Optional apply(T t) throws E; + public interface TransformFunction { + Optional apply(T t, C context) throws E; } /** @@ -41,12 +42,13 @@ public interface TransformFunction { * @return An empty optional if none of the items have changed. An optional containing a new list * otherwise. */ - public static Optional> transformList( - List items, TransformFunction transform) throws E { - List newItems = new ArrayList<>(); + public static + Optional> transformList( + List items, C context, TransformFunction transform) throws E { + List newItems = new ArrayList<>(); boolean listUpdated = false; - for (ITEM item : items) { - Optional newItem = transform.apply(item); + for (I item : items) { + Optional newItem = transform.apply(item, context); if (newItem.isPresent()) { newItems.add(newItem.get()); listUpdated = true; diff --git a/core/src/main/java/io/substrait/relation/Cross.java b/core/src/main/java/io/substrait/relation/Cross.java index c082957b8..b6ab4b42f 100644 --- a/core/src/main/java/io/substrait/relation/Cross.java +++ b/core/src/main/java/io/substrait/relation/Cross.java @@ -2,6 +2,7 @@ import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.stream.Stream; import org.immutables.value.Value; @@ -17,8 +18,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableCross.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/EmptyScan.java b/core/src/main/java/io/substrait/relation/EmptyScan.java index 0b47b3afd..95d304b49 100644 --- a/core/src/main/java/io/substrait/relation/EmptyScan.java +++ b/core/src/main/java/io/substrait/relation/EmptyScan.java @@ -1,13 +1,15 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable public abstract class EmptyScan extends AbstractReadRel { @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableEmptyScan.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java index 7f88282ae..63e868f63 100644 --- a/core/src/main/java/io/substrait/relation/Expand.java +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -21,8 +22,9 @@ public Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExpand.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 29bbe1a8c..8709f9d02 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -7,19 +7,20 @@ import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; -public class ExpressionCopyOnWriteVisitor - implements ExpressionVisitor, EXCEPTION> { +public class ExpressionCopyOnWriteVisitor + implements ExpressionVisitor, EmptyVisitationContext, E> { - private final RelCopyOnWriteVisitor relCopyOnWriteVisitor; + private final RelCopyOnWriteVisitor relCopyOnWriteVisitor; - public ExpressionCopyOnWriteVisitor(RelCopyOnWriteVisitor relCopyOnWriteVisitor) { + public ExpressionCopyOnWriteVisitor(RelCopyOnWriteVisitor relCopyOnWriteVisitor) { this.relCopyOnWriteVisitor = relCopyOnWriteVisitor; } - protected final RelCopyOnWriteVisitor getRelCopyOnWriteVisitor() { + protected final RelCopyOnWriteVisitor getRelCopyOnWriteVisitor() { return this.relCopyOnWriteVisitor; } @@ -29,160 +30,191 @@ public Optional visitLiteral(Expression.Literal literal) { } @Override - public Optional visit(Expression.NullLiteral expr) throws EXCEPTION { + public Optional visit(Expression.NullLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.BoolLiteral expr) throws EXCEPTION { + public Optional visit(Expression.BoolLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I8Literal expr) throws EXCEPTION { + public Optional visit(Expression.I8Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I16Literal expr) throws EXCEPTION { + public Optional visit(Expression.I16Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I32Literal expr) throws EXCEPTION { + public Optional visit(Expression.I32Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.I64Literal expr) throws EXCEPTION { + public Optional visit(Expression.I64Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FP32Literal expr) throws EXCEPTION { + public Optional visit(Expression.FP32Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FP64Literal expr) throws EXCEPTION { + public Optional visit(Expression.FP64Literal expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.StrLiteral expr) throws EXCEPTION { + public Optional visit(Expression.StrLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.BinaryLiteral expr) throws EXCEPTION { + public Optional visit(Expression.BinaryLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.TimeLiteral expr) throws EXCEPTION { + public Optional visit(Expression.TimeLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.DateLiteral expr) throws EXCEPTION { + public Optional visit(Expression.DateLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.TimestampLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.TimestampLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.TimestampTZLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.TimestampTZLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.PrecisionTimestampLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.PrecisionTimestampLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.PrecisionTimestampTZLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.PrecisionTimestampTZLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.IntervalYearLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.IntervalYearLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.IntervalDayLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.IntervalDayLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.IntervalCompoundLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.IntervalCompoundLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.UUIDLiteral expr) throws EXCEPTION { + public Optional visit(Expression.UUIDLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FixedCharLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.FixedCharLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.VarCharLiteral expr) throws EXCEPTION { + public Optional visit(Expression.VarCharLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.FixedBinaryLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.FixedBinaryLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.DecimalLiteral expr) throws EXCEPTION { + public Optional visit(Expression.DecimalLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.MapLiteral expr) throws EXCEPTION { + public Optional visit(Expression.MapLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.EmptyMapLiteral expr) throws EXCEPTION { + public Optional visit(Expression.EmptyMapLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.ListLiteral expr) throws EXCEPTION { + public Optional visit(Expression.ListLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.EmptyListLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.EmptyListLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.StructLiteral expr) throws EXCEPTION { + public Optional visit(Expression.StructLiteral expr, EmptyVisitationContext context) + throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.UserDefinedLiteral expr) throws EXCEPTION { + public Optional visit( + Expression.UserDefinedLiteral expr, EmptyVisitationContext context) throws E { return visitLiteral(expr); } @Override - public Optional visit(Expression.Switch expr) throws EXCEPTION { - var match = expr.match().accept(this); - var switchClauses = transformList(expr.switchClauses(), this::visitSwitchClause); - var defaultClause = expr.defaultClause().accept(this); + public Optional visit(Expression.Switch expr, EmptyVisitationContext context) + throws E { + var match = expr.match().accept(this, context); + var switchClauses = transformList(expr.switchClauses(), context, this::visitSwitchClause); + var defaultClause = expr.defaultClause().accept(this, context); if (allEmpty(match, switchClauses, defaultClause)) { return Optional.empty(); @@ -197,20 +229,21 @@ public Optional visit(Expression.Switch expr) throws EXCEPTION { } protected Optional visitSwitchClause( - Expression.SwitchClause switchClause) throws EXCEPTION { + Expression.SwitchClause switchClause, EmptyVisitationContext context) throws E { // This code does not visit the condition on the switch clause as that MUST be a Literal and the // visitor does not guarantee a Literal return type. If you wish to update the condition, // override this method. return switchClause .then() - .accept(this) + .accept(this, context) .map(then -> Expression.SwitchClause.builder().from(switchClause).then(then).build()); } @Override - public Optional visit(Expression.IfThen ifThen) throws EXCEPTION { - var ifClauses = transformList(ifThen.ifClauses(), this::visitIfClause); - var elseClause = ifThen.elseClause().accept(this); + public Optional visit(Expression.IfThen ifThen, EmptyVisitationContext context) + throws E { + var ifClauses = transformList(ifThen.ifClauses(), context, this::visitIfClause); + var elseClause = ifThen.elseClause().accept(this, context); if (allEmpty(ifClauses, elseClause)) { return Optional.empty(); @@ -223,10 +256,10 @@ public Optional visit(Expression.IfThen ifThen) throws EXCEPTION { .build()); } - protected Optional visitIfClause(Expression.IfClause ifClause) - throws EXCEPTION { - var condition = ifClause.condition().accept(this); - var then = ifClause.then().accept(this); + protected Optional visitIfClause( + Expression.IfClause ifClause, EmptyVisitationContext context) throws E { + var condition = ifClause.condition().accept(this, context); + var then = ifClause.then().accept(this, context); if (allEmpty(condition, then)) { return Optional.empty(); @@ -240,8 +273,9 @@ protected Optional visitIfClause(Expression.IfClause ifClau } @Override - public Optional visit(Expression.ScalarFunctionInvocation sfi) throws EXCEPTION { - return visitFunctionArguments(sfi.arguments()) + public Optional visit( + Expression.ScalarFunctionInvocation sfi, EmptyVisitationContext context) throws E { + return visitFunctionArguments(sfi.arguments(), context) .map( arguments -> Expression.ScalarFunctionInvocation.builder() @@ -251,10 +285,11 @@ public Optional visit(Expression.ScalarFunctionInvocation sfi) throw } @Override - public Optional visit(Expression.WindowFunctionInvocation wfi) throws EXCEPTION { - var arguments = visitFunctionArguments(wfi.arguments()); - var partitionBy = visitExprList(wfi.partitionBy()); - var sort = transformList(wfi.sort(), this::visitSortField); + public Optional visit( + Expression.WindowFunctionInvocation wfi, EmptyVisitationContext context) throws E { + var arguments = visitFunctionArguments(wfi.arguments(), context); + var partitionBy = visitExprList(wfi.partitionBy(), context); + var sort = transformList(wfi.sort(), context, this::visitSortField); if (allEmpty(arguments, partitionBy, sort)) { return Optional.empty(); @@ -269,16 +304,17 @@ public Optional visit(Expression.WindowFunctionInvocation wfi) throw } @Override - public Optional visit(Expression.Cast cast) throws EXCEPTION { + public Optional visit(Expression.Cast cast, EmptyVisitationContext context) throws E { return cast.input() - .accept(this) + .accept(this, context) .map(input -> Expression.Cast.builder().from(cast).input(input).build()); } @Override - public Optional visit(Expression.SingleOrList singleOrList) throws EXCEPTION { - var condition = singleOrList.condition().accept(this); - var options = visitExprList(singleOrList.options()); + public Optional visit( + Expression.SingleOrList singleOrList, EmptyVisitationContext context) throws E { + var condition = singleOrList.condition().accept(this, context); + var options = visitExprList(singleOrList.options(), context); if (allEmpty(condition, options)) { return Optional.empty(); @@ -292,10 +328,11 @@ public Optional visit(Expression.SingleOrList singleOrList) throws E } @Override - public Optional visit(Expression.MultiOrList multiOrList) throws EXCEPTION { - var conditions = visitExprList(multiOrList.conditions()); + public Optional visit( + Expression.MultiOrList multiOrList, EmptyVisitationContext context) throws E { + var conditions = visitExprList(multiOrList.conditions(), context); var optionCombinations = - transformList(multiOrList.optionCombinations(), this::visitMultiOrListRecord); + transformList(multiOrList.optionCombinations(), context, this::visitMultiOrListRecord); if (allEmpty(conditions, optionCombinations)) { return Optional.empty(); @@ -309,8 +346,8 @@ public Optional visit(Expression.MultiOrList multiOrList) throws EXC } protected Optional visitMultiOrListRecord( - Expression.MultiOrListRecord multiOrListRecord) throws EXCEPTION { - return visitExprList(multiOrListRecord.values()) + Expression.MultiOrListRecord multiOrListRecord, EmptyVisitationContext context) throws E { + return visitExprList(multiOrListRecord.values(), context) .map( values -> Expression.MultiOrListRecord.builder() @@ -320,8 +357,9 @@ protected Optional visitMultiOrListRecord( } @Override - public Optional visit(FieldReference fieldReference) throws EXCEPTION { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression()); + public Optional visit(FieldReference fieldReference, EmptyVisitationContext context) + throws E { + var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); @@ -330,26 +368,29 @@ public Optional visit(FieldReference fieldReference) throws EXCEPTIO } @Override - public Optional visit(Expression.SetPredicate setPredicate) throws EXCEPTION { + public Optional visit( + Expression.SetPredicate setPredicate, EmptyVisitationContext context) throws E { return setPredicate .tuples() - .accept(getRelCopyOnWriteVisitor()) + .accept(getRelCopyOnWriteVisitor(), context) .map(tuple -> Expression.SetPredicate.builder().from(setPredicate).tuples(tuple).build()); } @Override - public Optional visit(Expression.ScalarSubquery scalarSubquery) throws EXCEPTION { + public Optional visit( + Expression.ScalarSubquery scalarSubquery, EmptyVisitationContext context) throws E { return scalarSubquery .input() - .accept(getRelCopyOnWriteVisitor()) + .accept(getRelCopyOnWriteVisitor(), context) .map( input -> Expression.ScalarSubquery.builder().from(scalarSubquery).input(input).build()); } @Override - public Optional visit(Expression.InPredicate inPredicate) throws EXCEPTION { - var haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor()); - var needles = visitExprList(inPredicate.needles()); + public Optional visit( + Expression.InPredicate inPredicate, EmptyVisitationContext context) throws E { + var haystack = inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context); + var needles = visitExprList(inPredicate.needles(), context); if (allEmpty(haystack, needles)) { return Optional.empty(); @@ -364,37 +405,39 @@ public Optional visit(Expression.InPredicate inPredicate) throws EXC // utilities - protected Optional> visitExprList(List exprs) throws EXCEPTION { - return transformList(exprs, e -> e.accept(this)); + protected Optional> visitExprList( + List exprs, EmptyVisitationContext context) throws E { + return transformList(exprs, context, (e, c) -> e.accept(this, c)); } - private Optional visitOptionalExpression(Optional optExpr) - throws EXCEPTION { + private Optional visitOptionalExpression( + Optional optExpr, EmptyVisitationContext context) throws E { // not using optExpr.map to allow us to propagate the EXCEPTION nicely if (optExpr.isPresent()) { - return optExpr.get().accept(this); + return optExpr.get().accept(this, context); } return Optional.empty(); } - protected Optional> visitFunctionArguments(List funcArgs) - throws EXCEPTION { - return CopyOnWriteUtils.transformList( + protected Optional> visitFunctionArguments( + List funcArgs, EmptyVisitationContext context) throws E { + return CopyOnWriteUtils.transformList( funcArgs, - arg -> { + context, + (arg, c) -> { if (arg instanceof Expression expr) { - return expr.accept(this).flatMap(Optional::of); + return expr.accept(this, c).flatMap(Optional::of); } else { return Optional.empty(); } }); } - protected Optional visitSortField(Expression.SortField sortField) - throws EXCEPTION { + protected Optional visitSortField( + Expression.SortField sortField, EmptyVisitationContext context) throws E { return sortField .expr() - .accept(this) + .accept(this, context) .map(expr -> Expression.SortField.builder().from(sortField).expr(expr).build()); } } diff --git a/core/src/main/java/io/substrait/relation/ExtensionDdl.java b/core/src/main/java/io/substrait/relation/ExtensionDdl.java index bfc037b91..b95fc0c53 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionDdl.java +++ b/core/src/main/java/io/substrait/relation/ExtensionDdl.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -7,8 +8,9 @@ public abstract class ExtensionDdl extends AbstractDdlRel implements HasExtensio public abstract Extension.DdlExtensionObject getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionDdl.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/ExtensionLeaf.java b/core/src/main/java/io/substrait/relation/ExtensionLeaf.java index 92ca41172..7b990ae19 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionLeaf.java +++ b/core/src/main/java/io/substrait/relation/ExtensionLeaf.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -8,8 +9,9 @@ public abstract class ExtensionLeaf extends ZeroInputRel { public abstract Extension.LeafRelDetail getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionLeaf.Builder from(Extension.LeafRelDetail detail) { diff --git a/core/src/main/java/io/substrait/relation/ExtensionMulti.java b/core/src/main/java/io/substrait/relation/ExtensionMulti.java index 66b077d60..5ed3da08b 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionMulti.java +++ b/core/src/main/java/io/substrait/relation/ExtensionMulti.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; @@ -11,8 +12,9 @@ public abstract class ExtensionMulti extends AbstractRel { public abstract Extension.MultiRelDetail getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionMulti.Builder from( diff --git a/core/src/main/java/io/substrait/relation/ExtensionSingle.java b/core/src/main/java/io/substrait/relation/ExtensionSingle.java index e6cfb4ed5..69edb97d3 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionSingle.java +++ b/core/src/main/java/io/substrait/relation/ExtensionSingle.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -8,8 +9,9 @@ public abstract class ExtensionSingle extends SingleInputRel { public abstract Extension.SingleRelDetail getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionSingle.Builder from(Extension.SingleRelDetail detail, Rel input) { diff --git a/core/src/main/java/io/substrait/relation/ExtensionTable.java b/core/src/main/java/io/substrait/relation/ExtensionTable.java index 7857d71e5..5cbc4231e 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionTable.java +++ b/core/src/main/java/io/substrait/relation/ExtensionTable.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -8,8 +9,9 @@ public abstract class ExtensionTable extends AbstractReadRel { public abstract Extension.ExtensionTableDetail getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionTable.Builder from(Extension.ExtensionTableDetail detail) { diff --git a/core/src/main/java/io/substrait/relation/ExtensionWrite.java b/core/src/main/java/io/substrait/relation/ExtensionWrite.java index 78659c766..db591453b 100644 --- a/core/src/main/java/io/substrait/relation/ExtensionWrite.java +++ b/core/src/main/java/io/substrait/relation/ExtensionWrite.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -7,8 +8,9 @@ public abstract class ExtensionWrite extends AbstractWriteRel implements HasExte public abstract Extension.WriteExtensionObject getDetail(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableExtensionWrite.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Fetch.java b/core/src/main/java/io/substrait/relation/Fetch.java index de2a1f23e..4c76453ea 100644 --- a/core/src/main/java/io/substrait/relation/Fetch.java +++ b/core/src/main/java/io/substrait/relation/Fetch.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.type.Type; +import io.substrait.util.VisitationContext; import java.util.OptionalLong; import org.immutables.value.Value; @@ -18,8 +19,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableFetch.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Filter.java b/core/src/main/java/io/substrait/relation/Filter.java index 3085891b5..fcf54d5dc 100644 --- a/core/src/main/java/io/substrait/relation/Filter.java +++ b/core/src/main/java/io/substrait/relation/Filter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Immutable @@ -16,8 +17,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableFilter.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 77b47fd26..490bd7315 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -4,6 +4,7 @@ import io.substrait.proto.JoinRel; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.Optional; import java.util.stream.Stream; import org.immutables.value.Value; @@ -90,8 +91,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableJoin.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/LocalFiles.java b/core/src/main/java/io/substrait/relation/LocalFiles.java index eeee54de7..3b91d2bbd 100644 --- a/core/src/main/java/io/substrait/relation/LocalFiles.java +++ b/core/src/main/java/io/substrait/relation/LocalFiles.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.relation.files.FileOrFiles; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -12,8 +13,9 @@ public abstract class LocalFiles extends AbstractReadRel { public abstract List getItems(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableLocalFiles.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/NamedDdl.java b/core/src/main/java/io/substrait/relation/NamedDdl.java index 15e2ce193..873e4b481 100644 --- a/core/src/main/java/io/substrait/relation/NamedDdl.java +++ b/core/src/main/java/io/substrait/relation/NamedDdl.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -8,8 +9,9 @@ public abstract class NamedDdl extends AbstractDdlRel { public abstract List getNames(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableNamedDdl.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/NamedScan.java b/core/src/main/java/io/substrait/relation/NamedScan.java index 142729282..225a5b27b 100644 --- a/core/src/main/java/io/substrait/relation/NamedScan.java +++ b/core/src/main/java/io/substrait/relation/NamedScan.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -9,8 +10,9 @@ public abstract class NamedScan extends AbstractReadRel { public abstract List getNames(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableNamedScan.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/NamedUpdate.java b/core/src/main/java/io/substrait/relation/NamedUpdate.java index 0dd1ff349..f17947c85 100644 --- a/core/src/main/java/io/substrait/relation/NamedUpdate.java +++ b/core/src/main/java/io/substrait/relation/NamedUpdate.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -9,8 +10,9 @@ public abstract class NamedUpdate extends AbstractUpdate { public abstract List getNames(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableNamedUpdate.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/NamedWrite.java b/core/src/main/java/io/substrait/relation/NamedWrite.java index 25abdd307..e46f9b3cb 100644 --- a/core/src/main/java/io/substrait/relation/NamedWrite.java +++ b/core/src/main/java/io/substrait/relation/NamedWrite.java @@ -1,5 +1,6 @@ package io.substrait.relation; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -8,8 +9,9 @@ public abstract class NamedWrite extends AbstractWriteRel implements HasExtensio public abstract List getNames(); @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableNamedWrite.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Project.java b/core/src/main/java/io/substrait/relation/Project.java index 7ba6b26b9..dbf3128d6 100644 --- a/core/src/main/java/io/substrait/relation/Project.java +++ b/core/src/main/java/io/substrait/relation/Project.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import java.util.stream.Stream; import org.immutables.value.Value; @@ -23,8 +24,9 @@ public Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableProject.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Rel.java b/core/src/main/java/io/substrait/relation/Rel.java index 1472e9f4b..5f8af6312 100644 --- a/core/src/main/java/io/substrait/relation/Rel.java +++ b/core/src/main/java/io/substrait/relation/Rel.java @@ -4,6 +4,7 @@ import io.substrait.hint.Hint; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import java.util.Optional; import java.util.stream.IntStream; @@ -45,5 +46,6 @@ public static Remap offset(int start, int length) { } } - O accept(RelVisitor visitor) throws E; + O accept( + RelVisitor visitor, C context) throws E; } diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index ac51627c1..87d99ba72 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -11,6 +11,7 @@ import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -21,34 +22,33 @@ * By default, no subtree substitution will be performed. However, if a visit method is overridden * to return a non-empty optional value, then that value will replace the relation in the tree. */ -public class RelCopyOnWriteVisitor - implements RelVisitor, EXCEPTION> { +public class RelCopyOnWriteVisitor + implements RelVisitor, EmptyVisitationContext, E> { - private final ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor; + private final ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor; public RelCopyOnWriteVisitor() { this.expressionCopyOnWriteVisitor = new ExpressionCopyOnWriteVisitor<>(this); } - public RelCopyOnWriteVisitor( - ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor) { + public RelCopyOnWriteVisitor(ExpressionCopyOnWriteVisitor expressionCopyOnWriteVisitor) { this.expressionCopyOnWriteVisitor = expressionCopyOnWriteVisitor; } public RelCopyOnWriteVisitor( - Function, ExpressionCopyOnWriteVisitor> fn) { + Function, ExpressionCopyOnWriteVisitor> fn) { this.expressionCopyOnWriteVisitor = fn.apply(this); } - protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() { + protected ExpressionCopyOnWriteVisitor getExpressionCopyOnWriteVisitor() { return expressionCopyOnWriteVisitor; } @Override - public Optional visit(Aggregate aggregate) throws EXCEPTION { - var input = aggregate.getInput().accept(this); - var groupings = transformList(aggregate.getGroupings(), this::visitGrouping); - var measures = transformList(aggregate.getMeasures(), this::visitMeasure); + public Optional visit(Aggregate aggregate, EmptyVisitationContext context) throws E { + var input = aggregate.getInput().accept(this, context); + var groupings = transformList(aggregate.getGroupings(), context, this::visitGrouping); + var measures = transformList(aggregate.getMeasures(), context, this::visitMeasure); if (allEmpty(input, groupings, measures)) { return Optional.empty(); @@ -62,15 +62,16 @@ public Optional visit(Aggregate aggregate) throws EXCEPTION { .build()); } - protected Optional visitGrouping(Aggregate.Grouping grouping) - throws EXCEPTION { - return visitExprList(grouping.getExpressions()) + protected Optional visitGrouping( + Aggregate.Grouping grouping, EmptyVisitationContext context) throws E { + return visitExprList(grouping.getExpressions(), context) .map(exprs -> Aggregate.Grouping.builder().from(grouping).expressions(exprs).build()); } - protected Optional visitMeasure(Aggregate.Measure measure) throws EXCEPTION { - var preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter()); - var afi = visitAggregateFunction(measure.getFunction()); + protected Optional visitMeasure( + Aggregate.Measure measure, EmptyVisitationContext context) throws E { + var preMeasureFilter = visitOptionalExpression(measure.getPreMeasureFilter(), context); + var afi = visitAggregateFunction(measure.getFunction(), context); if (allEmpty(preMeasureFilter, afi)) { return Optional.empty(); @@ -84,9 +85,9 @@ protected Optional visitMeasure(Aggregate.Measure measure) th } protected Optional visitAggregateFunction( - AggregateFunctionInvocation afi) throws EXCEPTION { - var arguments = visitFunctionArguments(afi.arguments()); - var sort = transformList(afi.sort(), this::visitSortField); + AggregateFunctionInvocation afi, EmptyVisitationContext context) throws E { + var arguments = visitFunctionArguments(afi.arguments(), context); + var sort = transformList(afi.sort(), context, this::visitSortField); if (allEmpty(arguments, sort)) { return Optional.empty(); @@ -100,8 +101,8 @@ protected Optional visitAggregateFunction( } @Override - public Optional visit(EmptyScan emptyScan) throws EXCEPTION { - Optional filter = visitOptionalExpression(emptyScan.getFilter()); + public Optional visit(EmptyScan emptyScan, EmptyVisitationContext context) throws E { + Optional filter = visitOptionalExpression(emptyScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -114,17 +115,17 @@ public Optional visit(EmptyScan emptyScan) throws EXCEPTION { } @Override - public Optional visit(Fetch fetch) throws EXCEPTION { + public Optional visit(Fetch fetch, EmptyVisitationContext context) throws E { return fetch .getInput() - .accept(this) + .accept(this, context) .map(input -> Fetch.builder().from(fetch).input(input).build()); } @Override - public Optional visit(Filter filter) throws EXCEPTION { - var input = filter.getInput().accept(this); - var condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor()); + public Optional visit(Filter filter, EmptyVisitationContext context) throws E { + var input = filter.getInput().accept(this, context); + var condition = filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(input, condition)) { return Optional.empty(); @@ -138,11 +139,11 @@ public Optional visit(Filter filter) throws EXCEPTION { } @Override - public Optional visit(Join join) throws EXCEPTION { - var left = join.getLeft().accept(this); - var right = join.getRight().accept(this); - var condition = visitOptionalExpression(join.getCondition()); - var postFilter = visitOptionalExpression(join.getPostJoinFilter()); + public Optional visit(Join join, EmptyVisitationContext context) throws E { + var left = join.getLeft().accept(this, context); + var right = join.getRight().accept(this, context); + var condition = visitOptionalExpression(join.getCondition(), context); + var postFilter = visitOptionalExpression(join.getPostJoinFilter(), context); if (allEmpty(left, right, condition, postFilter)) { return Optional.empty(); @@ -158,14 +159,14 @@ public Optional visit(Join join) throws EXCEPTION { } @Override - public Optional visit(Set set) throws EXCEPTION { - return transformList(set.getInputs(), t -> t.accept(this)) + public Optional visit(Set set, EmptyVisitationContext context) throws E { + return transformList(set.getInputs(), context, (t, c) -> t.accept(this, c)) .map(s -> Set.builder().from(set).inputs(s).build()); } @Override - public Optional visit(NamedScan namedScan) throws EXCEPTION { - var filter = visitOptionalExpression(namedScan.getFilter()); + public Optional visit(NamedScan namedScan, EmptyVisitationContext context) throws E { + var filter = visitOptionalExpression(namedScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -175,8 +176,8 @@ public Optional visit(NamedScan namedScan) throws EXCEPTION { } @Override - public Optional visit(LocalFiles localFiles) throws EXCEPTION { - var filter = visitOptionalExpression(localFiles.getFilter()); + public Optional visit(LocalFiles localFiles, EmptyVisitationContext context) throws E { + var filter = visitOptionalExpression(localFiles.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -186,9 +187,9 @@ public Optional visit(LocalFiles localFiles) throws EXCEPTION { } @Override - public Optional visit(Project project) throws EXCEPTION { - var input = project.getInput().accept(this); - var expressions = visitExprList(project.getExpressions()); + public Optional visit(Project project, EmptyVisitationContext context) throws E { + var input = project.getInput().accept(this, context); + var expressions = visitExprList(project.getExpressions(), context); if (allEmpty(input, expressions)) { return Optional.empty(); @@ -202,39 +203,39 @@ public Optional visit(Project project) throws EXCEPTION { } @Override - public Optional visit(Expand expand) throws EXCEPTION { + public Optional visit(Expand expand, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(NamedWrite write) throws EXCEPTION { + public Optional visit(NamedWrite write, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(ExtensionWrite write) throws EXCEPTION { + public Optional visit(ExtensionWrite write, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(NamedDdl ddl) throws EXCEPTION { + public Optional visit(NamedDdl ddl, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(ExtensionDdl ddl) throws EXCEPTION { + public Optional visit(ExtensionDdl ddl, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(NamedUpdate update) throws EXCEPTION { + public Optional visit(NamedUpdate update, EmptyVisitationContext context) throws E { throw new UnsupportedOperationException(); } @Override - public Optional visit(Sort sort) throws EXCEPTION { - var input = sort.getInput().accept(this); - var sortFields = transformList(sort.getSortFields(), this::visitSortField); + public Optional visit(Sort sort, EmptyVisitationContext context) throws E { + var input = sort.getInput().accept(this, context); + var sortFields = transformList(sort.getSortFields(), context, this::visitSortField); if (allEmpty(input, sortFields)) { return Optional.empty(); @@ -248,9 +249,9 @@ public Optional visit(Sort sort) throws EXCEPTION { } @Override - public Optional visit(Cross cross) throws EXCEPTION { - var left = cross.getLeft().accept(this); - var right = cross.getRight().accept(this); + public Optional visit(Cross cross, EmptyVisitationContext context) throws E { + var left = cross.getLeft().accept(this, context); + var right = cross.getRight().accept(this, context); if (allEmpty(left, right)) { return Optional.empty(); @@ -264,8 +265,9 @@ public Optional visit(Cross cross) throws EXCEPTION { } @Override - public Optional visit(VirtualTableScan virtualTableScan) throws EXCEPTION { - var filter = visitOptionalExpression(virtualTableScan.getFilter()); + public Optional visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + throws E { + var filter = visitOptionalExpression(virtualTableScan.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -278,27 +280,30 @@ public Optional visit(VirtualTableScan virtualTableScan) throws EXCEPTION { } @Override - public Optional visit(ExtensionLeaf extensionLeaf) throws EXCEPTION { + public Optional visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) throws E { return Optional.empty(); } @Override - public Optional visit(ExtensionSingle extensionSingle) throws EXCEPTION { + public Optional visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + throws E { return extensionSingle .getInput() - .accept(this) + .accept(this, context) .map(input -> ExtensionSingle.builder().from(extensionSingle).input(input).build()); } @Override - public Optional visit(ExtensionMulti extensionMulti) throws EXCEPTION { - return transformList(extensionMulti.getInputs(), rel -> rel.accept(this)) + public Optional visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + throws E { + return transformList(extensionMulti.getInputs(), context, (rel, c) -> rel.accept(this, c)) .map(inputs -> ExtensionMulti.builder().from(extensionMulti).inputs(inputs).build()); } @Override - public Optional visit(ExtensionTable extensionTable) throws EXCEPTION { - var filter = visitOptionalExpression(extensionTable.getFilter()); + public Optional visit(ExtensionTable extensionTable, EmptyVisitationContext context) + throws E { + var filter = visitOptionalExpression(extensionTable.getFilter(), context); if (allEmpty(filter)) { return Optional.empty(); @@ -311,12 +316,12 @@ public Optional visit(ExtensionTable extensionTable) throws EXCEPTION { } @Override - public Optional visit(HashJoin hashJoin) throws EXCEPTION { - var left = hashJoin.getLeft().accept(this); - var right = hashJoin.getRight().accept(this); - var leftKeys = transformList(hashJoin.getLeftKeys(), this::visitFieldReference); - var rightKeys = transformList(hashJoin.getRightKeys(), this::visitFieldReference); - var postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter()); + public Optional visit(HashJoin hashJoin, EmptyVisitationContext context) throws E { + var left = hashJoin.getLeft().accept(this, context); + var right = hashJoin.getRight().accept(this, context); + var leftKeys = transformList(hashJoin.getLeftKeys(), context, this::visitFieldReference); + var rightKeys = transformList(hashJoin.getRightKeys(), context, this::visitFieldReference); + var postFilter = visitOptionalExpression(hashJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -333,12 +338,12 @@ public Optional visit(HashJoin hashJoin) throws EXCEPTION { } @Override - public Optional visit(MergeJoin mergeJoin) throws EXCEPTION { - var left = mergeJoin.getLeft().accept(this); - var right = mergeJoin.getRight().accept(this); - var leftKeys = transformList(mergeJoin.getLeftKeys(), this::visitFieldReference); - var rightKeys = transformList(mergeJoin.getRightKeys(), this::visitFieldReference); - var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter()); + public Optional visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws E { + var left = mergeJoin.getLeft().accept(this, context); + var right = mergeJoin.getRight().accept(this, context); + var leftKeys = transformList(mergeJoin.getLeftKeys(), context, this::visitFieldReference); + var rightKeys = transformList(mergeJoin.getRightKeys(), context, this::visitFieldReference); + var postFilter = visitOptionalExpression(mergeJoin.getPostJoinFilter(), context); if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { return Optional.empty(); @@ -355,10 +360,12 @@ public Optional visit(MergeJoin mergeJoin) throws EXCEPTION { } @Override - public Optional visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { - var left = nestedLoopJoin.getLeft().accept(this); - var right = nestedLoopJoin.getRight().accept(this); - var condition = nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor()); + public Optional visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + throws E { + var left = nestedLoopJoin.getLeft().accept(this, context); + var right = nestedLoopJoin.getRight().accept(this, context); + var condition = + nestedLoopJoin.getCondition().accept(getExpressionCopyOnWriteVisitor(), context); if (allEmpty(left, right, condition)) { return Optional.empty(); @@ -373,14 +380,18 @@ public Optional visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION { } @Override - public Optional visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION { + public Optional visit( + ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + throws E { var windowFunctions = - transformList(consistentPartitionWindow.getWindowFunctions(), this::visitWindowRelFunction); + transformList( + consistentPartitionWindow.getWindowFunctions(), context, this::visitWindowRelFunction); var partitionExpressions = transformList( consistentPartitionWindow.getPartitionExpressions(), - t -> t.accept(getExpressionCopyOnWriteVisitor())); - var sorts = transformList(consistentPartitionWindow.getSorts(), this::visitSortField); + context, + (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); + var sorts = transformList(consistentPartitionWindow.getSorts(), context, this::visitSortField); if (allEmpty(windowFunctions, partitionExpressions, sorts)) { return Optional.empty(); @@ -397,9 +408,10 @@ public Optional visit(ConsistentPartitionWindow consistentPartitionWindow) } protected Optional visitWindowRelFunction( - ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation) - throws EXCEPTION { - var functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments()); + ConsistentPartitionWindow.WindowRelFunctionInvocation windowRelFunctionInvocation, + EmptyVisitationContext context) + throws E { + var functionArgs = visitFunctionArguments(windowRelFunctionInvocation.arguments(), context); if (allEmpty(functionArgs)) { return Optional.empty(); @@ -414,13 +426,14 @@ protected Optional visitW // utilities - protected Optional> visitExprList(List exprs) throws EXCEPTION { - return transformList(exprs, t -> t.accept(getExpressionCopyOnWriteVisitor())); + protected Optional> visitExprList( + List exprs, EmptyVisitationContext context) throws E { + return transformList(exprs, context, (t, c) -> t.accept(getExpressionCopyOnWriteVisitor(), c)); } - public Optional visitFieldReference(FieldReference fieldReference) - throws EXCEPTION { - var inputExpression = visitOptionalExpression(fieldReference.inputExpression()); + public Optional visitFieldReference( + FieldReference fieldReference, EmptyVisitationContext context) throws E { + var inputExpression = visitOptionalExpression(fieldReference.inputExpression(), context); if (allEmpty(inputExpression)) { return Optional.empty(); } @@ -428,13 +441,14 @@ public Optional visitFieldReference(FieldReference fieldReferenc return Optional.of(FieldReference.builder().inputExpression(inputExpression).build()); } - protected Optional> visitFunctionArguments(List funcArgs) - throws EXCEPTION { - return CopyOnWriteUtils.transformList( + protected Optional> visitFunctionArguments( + List funcArgs, EmptyVisitationContext context) throws E { + return CopyOnWriteUtils.transformList( funcArgs, - arg -> { + context, + (arg, c) -> { if (arg instanceof Expression expr) { - return expr.accept(getExpressionCopyOnWriteVisitor()) + return expr.accept(getExpressionCopyOnWriteVisitor(), c) .flatMap(Optional::of); } else { return Optional.empty(); @@ -442,19 +456,19 @@ protected Optional> visitFunctionArguments(List f }); } - protected Optional visitSortField(Expression.SortField sortField) - throws EXCEPTION { + protected Optional visitSortField( + Expression.SortField sortField, EmptyVisitationContext context) throws E { return sortField .expr() - .accept(getExpressionCopyOnWriteVisitor()) + .accept(getExpressionCopyOnWriteVisitor(), context) .map(expr -> Expression.SortField.builder().from(sortField).expr(expr).build()); } - private Optional visitOptionalExpression(Optional optExpr) - throws EXCEPTION { + private Optional visitOptionalExpression( + Optional optExpr, EmptyVisitationContext context) throws E { // not using optExpr.map to allow us to propagate the THROWABLE nicely if (optExpr.isPresent()) { - return optExpr.get().accept(getExpressionCopyOnWriteVisitor()); + return optExpr.get().accept(getExpressionCopyOnWriteVisitor(), context); } return Optional.empty(); } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 32ad86ede..c746c48d3 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -40,13 +40,15 @@ import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; +import io.substrait.util.EmptyVisitationContext; import java.util.Collection; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; /** Converts from {@link io.substrait.relation.Rel} to {@link io.substrait.proto.Rel} */ -public class RelProtoConverter implements RelVisitor { +public class RelProtoConverter + implements RelVisitor { protected final ExpressionProtoConverter exprProtoConverter; protected final TypeProtoConverter typeProtoConverter; @@ -75,7 +77,7 @@ public io.substrait.proto.RelRoot toProto(Plan.Root relRoot) { } public io.substrait.proto.Rel toProto(io.substrait.relation.Rel rel) { - return rel.accept(this); + return rel.accept(this, null); } protected io.substrait.proto.Expression toProto(io.substrait.expression.Expression expression) { @@ -108,7 +110,7 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel } @Override - public Rel visit(Aggregate aggregate) throws RuntimeException { + public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { var builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) @@ -134,7 +136,7 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .setOutputType(toProto(measure.getFunction().getType())) .addAllArguments( IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) + .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) .collect(Collectors.toList())) .addAllSorts(toProtoS(measure.getFunction().sort())) .setFunctionReference( @@ -157,7 +159,7 @@ private AggregateRel.Grouping toProto(Aggregate.Grouping grouping) { } @Override - public Rel visit(EmptyScan emptyScan) throws RuntimeException { + public Rel visit(EmptyScan emptyScan, EmptyVisitationContext context) throws RuntimeException { return Rel.newBuilder() .setRead( ReadRel.newBuilder() @@ -169,7 +171,7 @@ public Rel visit(EmptyScan emptyScan) throws RuntimeException { } @Override - public Rel visit(Fetch fetch) throws RuntimeException { + public Rel visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { var builder = FetchRel.newBuilder() .setCommon(common(fetch)) @@ -183,7 +185,7 @@ public Rel visit(Fetch fetch) throws RuntimeException { } @Override - public Rel visit(Filter filter) throws RuntimeException { + public Rel visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { var builder = FilterRel.newBuilder() .setCommon(common(filter)) @@ -195,7 +197,7 @@ public Rel visit(Filter filter) throws RuntimeException { } @Override - public Rel visit(Join join) throws RuntimeException { + public Rel visit(Join join, EmptyVisitationContext context) throws RuntimeException { var builder = JoinRel.newBuilder() .setCommon(common(join)) @@ -212,7 +214,7 @@ public Rel visit(Join join) throws RuntimeException { } @Override - public Rel visit(Set set) throws RuntimeException { + public Rel visit(Set set, EmptyVisitationContext context) throws RuntimeException { var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto()); set.getInputs() .forEach( @@ -225,7 +227,7 @@ public Rel visit(Set set) throws RuntimeException { } @Override - public Rel visit(NamedScan namedScan) throws RuntimeException { + public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws RuntimeException { var builder = ReadRel.newBuilder() .setCommon(common(namedScan)) @@ -240,7 +242,7 @@ public Rel visit(NamedScan namedScan) throws RuntimeException { } @Override - public Rel visit(LocalFiles localFiles) throws RuntimeException { + public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws RuntimeException { var builder = ReadRel.newBuilder() .setCommon(common(localFiles)) @@ -260,7 +262,8 @@ public Rel visit(LocalFiles localFiles) throws RuntimeException { } @Override - public Rel visit(ExtensionTable extensionTable) throws RuntimeException { + public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context) + throws RuntimeException { ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder().setDetail(extensionTable.getDetail().toProto(this)); var builder = @@ -274,7 +277,7 @@ public Rel visit(ExtensionTable extensionTable) throws RuntimeException { } @Override - public Rel visit(HashJoin hashJoin) throws RuntimeException { + public Rel visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { var builder = HashJoinRel.newBuilder() .setCommon(common(hashJoin)) @@ -299,7 +302,7 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException { } @Override - public Rel visit(MergeJoin mergeJoin) throws RuntimeException { + public Rel visit(MergeJoin mergeJoin, EmptyVisitationContext context) throws RuntimeException { var builder = MergeJoinRel.newBuilder() .setCommon(common(mergeJoin)) @@ -324,7 +327,8 @@ public Rel visit(MergeJoin mergeJoin) throws RuntimeException { } @Override - public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { + public Rel visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + throws RuntimeException { var builder = NestedLoopJoinRel.newBuilder() .setCommon(common(nestedLoopJoin)) @@ -338,7 +342,9 @@ public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException { } @Override - public Rel visit(ConsistentPartitionWindow consistentPartitionWindow) throws RuntimeException { + public Rel visit( + ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + throws RuntimeException { var builder = ConsistentPartitionWindowRel.newBuilder() .setCommon(common(consistentPartitionWindow)) @@ -357,7 +363,7 @@ public Rel visit(ConsistentPartitionWindow consistentPartitionWindow) throws Run } @Override - public Rel visit(NamedWrite write) throws RuntimeException { + public Rel visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { var builder = WriteRel.newBuilder() .setCommon(common(write)) @@ -372,7 +378,7 @@ public Rel visit(NamedWrite write) throws RuntimeException { } @Override - public Rel visit(ExtensionWrite write) throws RuntimeException { + public Rel visit(ExtensionWrite write, EmptyVisitationContext context) throws RuntimeException { var builder = WriteRel.newBuilder() .setCommon(common(write)) @@ -388,7 +394,7 @@ public Rel visit(ExtensionWrite write) throws RuntimeException { } @Override - public Rel visit(NamedDdl ddl) throws RuntimeException { + public Rel visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { var builder = DdlRel.newBuilder() .setCommon(common(ddl)) @@ -405,7 +411,7 @@ public Rel visit(NamedDdl ddl) throws RuntimeException { } @Override - public Rel visit(ExtensionDdl ddl) throws RuntimeException { + public Rel visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { var builder = DdlRel.newBuilder() .setCommon(common(ddl)) @@ -422,7 +428,8 @@ public Rel visit(ExtensionDdl ddl) throws RuntimeException { return Rel.newBuilder().setDdl(builder).build(); } - public Rel visit(NamedUpdate update) throws RuntimeException { + @Override + public Rel visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { var builder = UpdateRel.newBuilder() .setNamedTable(NamedTable.newBuilder().addAllNames(update.getNames())) @@ -456,7 +463,7 @@ private List toProtoWindowRelFun var arguments = IntStream.range(0, args.size()) - .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) + .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor, null)) .collect(Collectors.toList()); var options = f.options().stream() @@ -479,7 +486,7 @@ private List toProtoWindowRelFun } @Override - public Rel visit(Project project) throws RuntimeException { + public Rel visit(Project project, EmptyVisitationContext context) throws RuntimeException { var builder = ProjectRel.newBuilder() .setCommon(common(project)) @@ -491,7 +498,7 @@ public Rel visit(Project project) throws RuntimeException { } @Override - public Rel visit(Expand expand) throws RuntimeException { + public Rel visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { var builder = ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); @@ -521,7 +528,7 @@ public Rel visit(Expand expand) throws RuntimeException { } @Override - public Rel visit(Sort sort) throws RuntimeException { + public Rel visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { var builder = SortRel.newBuilder() .setCommon(common(sort)) @@ -533,7 +540,7 @@ public Rel visit(Sort sort) throws RuntimeException { } @Override - public Rel visit(Cross cross) throws RuntimeException { + public Rel visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { var builder = CrossRel.newBuilder() .setCommon(common(cross)) @@ -545,7 +552,8 @@ public Rel visit(Cross cross) throws RuntimeException { } @Override - public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException { + public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + throws RuntimeException { var builder = ReadRel.newBuilder() .setCommon(common(virtualTableScan)) @@ -567,7 +575,8 @@ public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException { } @Override - public Rel visit(ExtensionLeaf extensionLeaf) throws RuntimeException { + public Rel visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) + throws RuntimeException { var builder = ExtensionLeafRel.newBuilder() .setCommon(common(extensionLeaf)) @@ -576,7 +585,8 @@ public Rel visit(ExtensionLeaf extensionLeaf) throws RuntimeException { } @Override - public Rel visit(ExtensionSingle extensionSingle) throws RuntimeException { + public Rel visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + throws RuntimeException { var builder = ExtensionSingleRel.newBuilder() .setCommon(common(extensionSingle)) @@ -586,7 +596,8 @@ public Rel visit(ExtensionSingle extensionSingle) throws RuntimeException { } @Override - public Rel visit(ExtensionMulti extensionMulti) throws RuntimeException { + public Rel visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + throws RuntimeException { List inputs = extensionMulti.getInputs().stream().map(this::toProto).collect(Collectors.toList()); var builder = diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 7fdefa95d..23ce99fea 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -3,57 +3,58 @@ import io.substrait.relation.physical.HashJoin; import io.substrait.relation.physical.MergeJoin; import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.util.VisitationContext; -public interface RelVisitor { - OUTPUT visit(Aggregate aggregate) throws EXCEPTION; +public interface RelVisitor { + O visit(Aggregate aggregate, C context) throws E; - OUTPUT visit(EmptyScan emptyScan) throws EXCEPTION; + O visit(EmptyScan emptyScan, C context) throws E; - OUTPUT visit(Fetch fetch) throws EXCEPTION; + O visit(Fetch fetch, C context) throws E; - OUTPUT visit(Filter filter) throws EXCEPTION; + O visit(Filter filter, C context) throws E; - OUTPUT visit(Join join) throws EXCEPTION; + O visit(Join join, C context) throws E; - OUTPUT visit(Set set) throws EXCEPTION; + O visit(Set set, C context) throws E; - OUTPUT visit(NamedScan namedScan) throws EXCEPTION; + O visit(NamedScan namedScan, C context) throws E; - OUTPUT visit(LocalFiles localFiles) throws EXCEPTION; + O visit(LocalFiles localFiles, C context) throws E; - OUTPUT visit(Project project) throws EXCEPTION; + O visit(Project project, C context) throws E; - OUTPUT visit(Expand expand) throws EXCEPTION; + O visit(Expand expand, C context) throws E; - OUTPUT visit(Sort sort) throws EXCEPTION; + O visit(Sort sort, C context) throws E; - OUTPUT visit(Cross cross) throws EXCEPTION; + O visit(Cross cross, C context) throws E; - OUTPUT visit(VirtualTableScan virtualTableScan) throws EXCEPTION; + O visit(VirtualTableScan virtualTableScan, C context) throws E; - OUTPUT visit(ExtensionLeaf extensionLeaf) throws EXCEPTION; + O visit(ExtensionLeaf extensionLeaf, C context) throws E; - OUTPUT visit(ExtensionSingle extensionSingle) throws EXCEPTION; + O visit(ExtensionSingle extensionSingle, C context) throws E; - OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION; + O visit(ExtensionMulti extensionMulti, C context) throws E; - OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION; + O visit(ExtensionTable extensionTable, C context) throws E; - OUTPUT visit(HashJoin hashJoin) throws EXCEPTION; + O visit(HashJoin hashJoin, C context) throws E; - OUTPUT visit(MergeJoin mergeJoin) throws EXCEPTION; + O visit(MergeJoin mergeJoin, C context) throws E; - OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION; + O visit(NestedLoopJoin nestedLoopJoin, C context) throws E; - OUTPUT visit(ConsistentPartitionWindow consistentPartitionWindow) throws EXCEPTION; + O visit(ConsistentPartitionWindow consistentPartitionWindow, C context) throws E; - OUTPUT visit(NamedWrite write) throws EXCEPTION; + O visit(NamedWrite write, C context) throws E; - OUTPUT visit(ExtensionWrite write) throws EXCEPTION; + O visit(ExtensionWrite write, C context) throws E; - OUTPUT visit(NamedDdl ddl) throws EXCEPTION; + O visit(NamedDdl ddl, C context) throws E; - OUTPUT visit(ExtensionDdl ddl) throws EXCEPTION; + O visit(ExtensionDdl ddl, C context) throws E; - OUTPUT visit(NamedUpdate update) throws EXCEPTION; + O visit(NamedUpdate update, C context) throws E; } diff --git a/core/src/main/java/io/substrait/relation/Set.java b/core/src/main/java/io/substrait/relation/Set.java index a1623ad1a..697cda1be 100644 --- a/core/src/main/java/io/substrait/relation/Set.java +++ b/core/src/main/java/io/substrait/relation/Set.java @@ -3,6 +3,7 @@ import io.substrait.proto.SetRel; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -131,8 +132,9 @@ private Type.Struct coalesceNullabilityIntersection(Type.Struct first, List O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableSet.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/Sort.java b/core/src/main/java/io/substrait/relation/Sort.java index d62ba3117..09bc01860 100644 --- a/core/src/main/java/io/substrait/relation/Sort.java +++ b/core/src/main/java/io/substrait/relation/Sort.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -17,8 +18,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableSort.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index 88e78d7ff..6eb7a361d 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -3,6 +3,7 @@ import io.substrait.expression.Expression; import io.substrait.type.Type; import io.substrait.type.TypeVisitor; +import io.substrait.util.VisitationContext; import java.util.List; import org.immutables.value.Value; @@ -37,8 +38,9 @@ protected void check() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableVirtualTableScan.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 6d0e68f8a..1d9cc3c54 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -8,6 +8,7 @@ import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -75,8 +76,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableHashJoin.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java index 5435a4c2a..4f7facd32 100644 --- a/core/src/main/java/io/substrait/relation/physical/MergeJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/MergeJoin.java @@ -8,6 +8,7 @@ import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.List; import java.util.Optional; import java.util.stream.Stream; @@ -75,8 +76,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableMergeJoin.Builder builder() { diff --git a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java index 722fdb471..233aff176 100644 --- a/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/NestedLoopJoin.java @@ -7,6 +7,7 @@ import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import io.substrait.util.VisitationContext; import java.util.stream.Stream; import org.immutables.value.Value; @@ -69,8 +70,9 @@ protected Type.Struct deriveRecordType() { } @Override - public O accept(RelVisitor visitor) throws E { - return visitor.visit(this); + public O accept( + RelVisitor visitor, C context) throws E { + return visitor.visit(this, context); } public static ImmutableNestedLoopJoin.Builder builder() { diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 9d8844dbc..fc87708f4 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -5,6 +5,7 @@ import io.substrait.function.NullableType; import io.substrait.function.ParameterizedType; import io.substrait.function.TypeExpression; +import io.substrait.util.VisitationContext; import org.immutables.value.Value; @Value.Enclosing @@ -18,9 +19,10 @@ public static TypeCreator withNullability(boolean nullable) { R accept(final TypeVisitor typeVisitor) throws E; @Override - default R accept( - SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor) throws E { - return fnArgVisitor.visitType(fnDef, argIdx, this); + default R accept( + SimpleExtension.Function fnDef, int argIdx, FuncArgVisitor fnArgVisitor, C context) + throws E { + return fnArgVisitor.visitType(fnDef, argIdx, this, context); } @Value.Immutable diff --git a/core/src/main/java/io/substrait/util/EmptyVisitationContext.java b/core/src/main/java/io/substrait/util/EmptyVisitationContext.java new file mode 100644 index 000000000..acdf626c2 --- /dev/null +++ b/core/src/main/java/io/substrait/util/EmptyVisitationContext.java @@ -0,0 +1,3 @@ +package io.substrait.util; + +public class EmptyVisitationContext implements VisitationContext {} diff --git a/core/src/main/java/io/substrait/util/VisitationContext.java b/core/src/main/java/io/substrait/util/VisitationContext.java new file mode 100644 index 000000000..2f783c783 --- /dev/null +++ b/core/src/main/java/io/substrait/util/VisitationContext.java @@ -0,0 +1,3 @@ +package io.substrait.util; + +public interface VisitationContext {} diff --git a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java index 453560e50..22c54ecca 100644 --- a/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/GenericRoundtripTest.java @@ -57,7 +57,7 @@ public void roundtripTest(Method m, List paramInst, UnsupportedTypeGener null, EMPTY_TYPE, new ProtoRelConverter(new ExtensionCollector(), defaultExtensionCollection)); - assertEquals(val, from.from(val.accept(to))); + assertEquals(val, from.from(val.accept(to, null))); } // Parametrized case generator diff --git a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java index 5d0a7e465..f69f5d2fa 100644 --- a/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/IfThenRoundtripTest.java @@ -26,7 +26,7 @@ void ifThenNotNullable() { var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(ifRel, from.from(ifRel.accept(to))); + assertEquals(ifRel, from.from(ifRel.accept(to, null))); } @Test @@ -40,6 +40,6 @@ void ifThenNullable() { var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(ifRel, from.from(ifRel.accept(to))); + assertEquals(ifRel, from.from(ifRel.accept(to, null))); } } diff --git a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java index 8686e8660..5e1d58cf1 100644 --- a/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java @@ -17,6 +17,6 @@ void decimal() { var val = ExpressionCreator.decimal(false, BigDecimal.TEN, 10, 2); var to = new ExpressionProtoConverter(null, null); var from = new ProtoExpressionConverter(null, null, EMPTY_TYPE, protoRelConverter); - assertEquals(val, from.from(val.accept(to))); + assertEquals(val, from.from(val.accept(to, null))); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 416a5d235..2f6b2d6bd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -6,6 +6,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelCopyOnWriteVisitor; import io.substrait.type.NamedStruct; +import io.substrait.util.EmptyVisitationContext; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -78,13 +79,13 @@ private TableGatherer() { */ public static Map, NamedStruct> gatherTables(Rel rootRel) { var visitor = new TableGatherer(); - rootRel.accept(visitor); + rootRel.accept(visitor, null); return visitor.tableMap; } @Override - public Optional visit(NamedScan namedScan) { - super.visit(namedScan); + public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { + super.visit(namedScan, context); List tableName = namedScan.getNames(); if (tableMap.containsKey(tableName)) { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 4188f527f..999acdf17 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import static io.substrait.isthmus.SqlConverterBase.EXTENSION_COLLECTION; import static io.substrait.isthmus.SqlToSubstrait.EXTENSION_COLLECTION; import com.google.common.collect.ImmutableList; @@ -22,6 +23,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.Set; import io.substrait.relation.Sort; +import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -56,7 +58,8 @@ * RelVisitor to convert Substrait Rel plan to Calcite RelNode plan. Unsupported Rel node will call * visitFallback and throw UnsupportedOperationException. */ -public class SubstraitRelNodeConverter extends AbstractRelVisitor { +public class SubstraitRelNodeConverter + extends AbstractRelVisitor { protected final RelDataTypeFactory typeFactory; @@ -133,30 +136,34 @@ public static RelNode convert( return relRoot.accept( new SubstraitRelNodeConverter( - EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder)); + EXTENSION_COLLECTION, relOptCluster.getTypeFactory(), relBuilder), + null); } @Override - public RelNode visit(Filter filter) throws RuntimeException { - RelNode input = filter.getInput().accept(this); - RexNode filterCondition = filter.getCondition().accept(expressionRexConverter); + public RelNode visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { + RelNode input = filter.getInput().accept(this, context); + RexNode filterCondition = filter.getCondition().accept(expressionRexConverter, context); RelNode node = relBuilder.push(input).filter(filterCondition).build(); return applyRemap(node, filter.getRemap()); } @Override - public RelNode visit(NamedScan namedScan) throws RuntimeException { + public RelNode visit(NamedScan namedScan, EmptyVisitationContext context) + throws RuntimeException { RelNode node = relBuilder.scan(namedScan.getNames()).build(); return applyRemap(node, namedScan.getRemap()); } @Override - public RelNode visit(LocalFiles localFiles) throws RuntimeException { - return visitFallback(localFiles); + public RelNode visit(LocalFiles localFiles, EmptyVisitationContext context) + throws RuntimeException { + return visitFallback(localFiles, context); } @Override - public RelNode visit(EmptyScan emptyScan) throws RuntimeException { + public RelNode visit(EmptyScan emptyScan, EmptyVisitationContext context) + throws RuntimeException { RelDataType rowType = typeConverter.toCalcite(relBuilder.getTypeFactory(), emptyScan.getInitialSchema().struct()); RelNode node = LogicalValues.create(relBuilder.getCluster(), rowType, ImmutableList.of()); @@ -164,15 +171,14 @@ public RelNode visit(EmptyScan emptyScan) throws RuntimeException { } @Override - public RelNode visit(Project project) throws RuntimeException { - RelNode child = project.getInput().accept(this); - + public RelNode visit(Project project, EmptyVisitationContext context) throws RuntimeException { + RelNode child = project.getInput().accept(this, context); Stream directOutputs = IntStream.range(0, child.getRowType().getFieldCount()) .mapToObj(fieldIndex -> rexBuilder.makeInputRef(child, fieldIndex)); Stream exprs = - project.getExpressions().stream().map(expr -> expr.accept(expressionRexConverter)); + project.getExpressions().stream().map(expr -> expr.accept(expressionRexConverter, context)); List rexExprs = Stream.concat(directOutputs, exprs).collect(java.util.stream.Collectors.toList()); @@ -182,9 +188,9 @@ public RelNode visit(Project project) throws RuntimeException { } @Override - public RelNode visit(Cross cross) throws RuntimeException { - RelNode left = cross.getLeft().accept(this); - RelNode right = cross.getRight().accept(this); + public RelNode visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { + RelNode left = cross.getLeft().accept(this, context); + RelNode right = cross.getRight().accept(this, context); // Calcite represents CROSS JOIN as the equivalent INNER JOIN with true condition RelNode node = relBuilder.push(left).push(right).join(JoinRelType.INNER, relBuilder.literal(true)).build(); @@ -192,12 +198,12 @@ public RelNode visit(Cross cross) throws RuntimeException { } @Override - public RelNode visit(Join join) throws RuntimeException { - var left = join.getLeft().accept(this); - var right = join.getRight().accept(this); - var condition = + public RelNode visit(Join join, EmptyVisitationContext context) throws RuntimeException { + RelNode left = join.getLeft().accept(this, context); + RelNode right = join.getRight().accept(this, context); + RexNode condition = join.getCondition() - .map(c -> c.accept(expressionRexConverter)) + .map(c -> c.accept(expressionRexConverter, context)) .orElse(relBuilder.literal(true)); var joinType = switch (join.getJoinType()) { @@ -219,12 +225,12 @@ public RelNode visit(Join join) throws RuntimeException { } @Override - public RelNode visit(Set set) throws RuntimeException { + public RelNode visit(Set set, EmptyVisitationContext context) throws RuntimeException { int numInputs = set.getInputs().size(); set.getInputs() .forEach( input -> { - relBuilder.push(input.accept(this)); + relBuilder.push(input.accept(this, context)); }); // TODO: MINUS_MULTISET and INTERSECTION_PRIMARY mappings are set to be removed as they do not // correspond to the Calcite relations they are associated with. They are retained for now @@ -247,20 +253,21 @@ public RelNode visit(Set set) throws RuntimeException { } @Override - public RelNode visit(Aggregate aggregate) throws RuntimeException { + public RelNode visit(Aggregate aggregate, EmptyVisitationContext context) + throws RuntimeException { if (!PreCalciteAggregateValidator.isValidCalciteAggregate(aggregate)) { aggregate = PreCalciteAggregateValidator.PreCalciteAggregateTransformer .transformToValidCalciteAggregate(aggregate); } - RelNode child = aggregate.getInput().accept(this); + RelNode child = aggregate.getInput().accept(this, context); var groupExprLists = aggregate.getGroupings().stream() .map( gr -> gr.getExpressions().stream() - .map(expr -> expr.accept(expressionRexConverter)) + .map(expr -> expr.accept(expressionRexConverter, context)) .collect(java.util.stream.Collectors.toList())) .collect(java.util.stream.Collectors.toList()); List groupExprs = @@ -269,13 +276,13 @@ public RelNode visit(Aggregate aggregate) throws RuntimeException { List aggregateCalls = aggregate.getMeasures().stream() - .map(this::fromMeasure) + .map(measure -> fromMeasure(measure, context)) .collect(java.util.stream.Collectors.toList()); RelNode node = relBuilder.push(child).aggregate(groupKey, aggregateCalls).build(); return applyRemap(node, aggregate.getRemap()); } - private AggregateCall fromMeasure(Aggregate.Measure measure) { + private AggregateCall fromMeasure(Aggregate.Measure measure, EmptyVisitationContext context) { var eArgs = measure.getFunction().arguments(); var arguments = IntStream.range(0, measure.getFunction().arguments().size()) @@ -283,7 +290,11 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { i -> eArgs .get(i) - .accept(measure.getFunction().declaration(), i, expressionRexConverter)) + .accept( + measure.getFunction().declaration(), + i, + expressionRexConverter, + context)) .collect(java.util.stream.Collectors.toList()); var operator = aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc( @@ -318,7 +329,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { int filterArg = -1; if (measure.getPreMeasureFilter().isPresent()) { - RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter); + RexNode filter = measure.getPreMeasureFilter().get().accept(expressionRexConverter, context); filterArg = ((RexInputRef) filter).getIndex(); } @@ -327,7 +338,7 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { relCollation = RelCollations.of( measure.getFunction().sort().stream() - .map(sortField -> toRelFieldCollation(sortField)) + .map(sortField -> toRelFieldCollation(sortField, context)) .collect(Collectors.toList())); } @@ -346,19 +357,19 @@ private AggregateCall fromMeasure(Aggregate.Measure measure) { } @Override - public RelNode visit(Sort sort) throws RuntimeException { - RelNode child = sort.getInput().accept(this); + public RelNode visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { + RelNode child = sort.getInput().accept(this, context); List sortExpressions = sort.getSortFields().stream() - .map(this::directedRexNode) - .collect(java.util.stream.Collectors.toList()); + .map(sortField -> directedRexNode(sortField, context)) + .collect(Collectors.toList()); RelNode node = relBuilder.push(child).sort(sortExpressions).build(); return applyRemap(node, sort.getRemap()); } - private RexNode directedRexNode(Expression.SortField sortField) { + private RexNode directedRexNode(Expression.SortField sortField, EmptyVisitationContext context) { var expression = sortField.expr(); - var rexNode = expression.accept(expressionRexConverter); + var rexNode = expression.accept(expressionRexConverter, context); var sortDirection = sortField.direction(); return switch (sortDirection) { case ASC_NULLS_FIRST -> relBuilder.nullsFirst(rexNode); @@ -371,8 +382,8 @@ private RexNode directedRexNode(Expression.SortField sortField) { } @Override - public RelNode visit(Fetch fetch) throws RuntimeException { - RelNode child = fetch.getInput().accept(this); + public RelNode visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { + RelNode child = fetch.getInput().accept(this, context); var optCount = fetch.getCount(); long count = optCount.orElse(-1L); var offset = fetch.getOffset(); @@ -386,9 +397,10 @@ public RelNode visit(Fetch fetch) throws RuntimeException { return applyRemap(node, fetch.getRemap()); } - private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) { + private RelFieldCollation toRelFieldCollation( + Expression.SortField sortField, EmptyVisitationContext context) { var expression = sortField.expr(); - var rex = expression.accept(expressionRexConverter); + var rex = expression.accept(expressionRexConverter, context); var sortDirection = sortField.direction(); RexSlot rexSlot = (RexSlot) rex; int fieldIndex = rexSlot.getIndex(); @@ -414,7 +426,7 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField) { } @Override - public RelNode visitFallback(Rel rel) throws RuntimeException { + public RelNode visitFallback(Rel rel, EmptyVisitationContext context) throws RuntimeException { throw new UnsupportedOperationException( String.format( "Rel %s of type %s not handled by visitor type %s.", diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index c3089e836..6001c1755 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -6,6 +6,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelCopyOnWriteVisitor; import io.substrait.type.NamedStruct; +import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -90,7 +91,7 @@ public RelNode convert(Rel rel) { CalciteSchema rootSchema = toSchema(rel); RelBuilder relBuilder = createRelBuilder(rootSchema); SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); - return rel.accept(converter); + return rel.accept(converter, null); } /** @@ -176,13 +177,13 @@ private NamedStructGatherer() { public static Map, NamedStruct> gatherTables(Rel rel) { var visitor = new NamedStructGatherer(); - rel.accept(visitor); + rel.accept(visitor, null); return visitor.tableMap; } @Override - public Optional visit(NamedScan namedScan) { - Optional result = super.visit(namedScan); + public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { + Optional result = super.visit(namedScan, context); List tableName = namedScan.getNames(); tableMap.put(tableName, namedScan.getInitialSchema()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java index 3d87fc392..d4cfc3dae 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java @@ -5,6 +5,7 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.isthmus.*; import io.substrait.type.Type; +import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; import java.util.List; import java.util.Optional; @@ -49,7 +50,8 @@ public class CallConverters { * is stored within a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link * org.apache.calcite.rex.RexLiteral} and then re-interpreted to have the correct type. * - *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral)} for this conversion. + *

See {@link ExpressionRexConverter#visit(Expression.UserDefinedLiteral, + * EmptyVisitationContext)} for this conversion. * *

When converting from Calcite to Substrait, this call converter extracts the {@link * Expression.UserDefinedLiteral} that was stored. diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e6342bfc7..58d183d19 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -21,6 +21,7 @@ import io.substrait.type.StringTypeVisitor; import io.substrait.type.Type; import io.substrait.util.DecimalUtil; +import io.substrait.util.EmptyVisitationContext; import java.math.BigDecimal; import java.util.Collections; import java.util.List; @@ -53,8 +54,9 @@ * ExpressionVisitor that converts Substrait Expression into Calcite Rex. Unsupported Expression * node will call visitFallback and throw UnsupportedOperationException. */ -public class ExpressionRexConverter extends AbstractExpressionVisitor - implements FunctionArg.FuncArgVisitor { +public class ExpressionRexConverter + extends AbstractExpressionVisitor + implements FunctionArg.FuncArgVisitor { protected final RelDataTypeFactory typeFactory; protected final TypeConverter typeConverter; protected final RexBuilder rexBuilder; @@ -95,77 +97,90 @@ public void setRelNodeConverter(final SubstraitRelNodeConverter substraitRelNode } @Override - public RexNode visit(Expression.NullLiteral expr) throws RuntimeException { + public RexNode visit(Expression.NullLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral(null, typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.UserDefinedLiteral expr) throws RuntimeException { + public RexNode visit(Expression.UserDefinedLiteral expr, EmptyVisitationContext context) + throws RuntimeException { var binaryLiteral = rexBuilder.makeBinaryLiteral(new ByteString(expr.value().toByteArray())); var type = typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeReinterpretCast(type, binaryLiteral, rexBuilder.makeLiteral(false)); } @Override - public RexNode visit(Expression.BoolLiteral expr) throws RuntimeException { + public RexNode visit(Expression.BoolLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); } @Override - public RexNode visit(Expression.I8Literal expr) throws RuntimeException { + public RexNode visit(Expression.I8Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I16Literal expr) throws RuntimeException { + public RexNode visit(Expression.I16Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I32Literal expr) throws RuntimeException { + public RexNode visit(Expression.I32Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.I64Literal expr) throws RuntimeException { + public RexNode visit(Expression.I64Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FP32Literal expr) throws RuntimeException { + public RexNode visit(Expression.FP32Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FP64Literal expr) throws RuntimeException { + public RexNode visit(Expression.FP64Literal expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.FixedCharLiteral expr) throws RuntimeException { + public RexNode visit(Expression.FixedCharLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral(expr.value()); } @Override - public RexNode visit(Expression.StrLiteral expr) throws RuntimeException { + public RexNode visit(Expression.StrLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType()), true); } @Override - public RexNode visit(Expression.VarCharLiteral expr) throws RuntimeException { + public RexNode visit(Expression.VarCharLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType()), true); } @Override - public RexNode visit(Expression.FixedBinaryLiteral expr) throws RuntimeException { + public RexNode visit(Expression.FixedBinaryLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( new ByteString(expr.value().toByteArray()), typeConverter.toCalcite(typeFactory, expr.getType()), @@ -173,7 +188,8 @@ public RexNode visit(Expression.FixedBinaryLiteral expr) throws RuntimeException } @Override - public RexNode visit(Expression.BinaryLiteral expr) throws RuntimeException { + public RexNode visit(Expression.BinaryLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( new ByteString(expr.value().toByteArray()), typeConverter.toCalcite(typeFactory, expr.getType()), @@ -181,7 +197,8 @@ public RexNode visit(Expression.BinaryLiteral expr) throws RuntimeException { } @Override - public RexNode visit(Expression.TimeLiteral expr) throws RuntimeException { + public RexNode visit(Expression.TimeLiteral expr, EmptyVisitationContext context) + throws RuntimeException { // Expression.TimeLiteral is Microseconds // Construct a TimeString : // 1. Truncate microseconds to seconds @@ -198,39 +215,45 @@ public RexNode visit(Expression.TimeLiteral expr) throws RuntimeException { } @Override - public RexNode visit(SingleOrList expr) throws RuntimeException { - var lhs = expr.condition().accept(this); + public RexNode visit(SingleOrList expr, EmptyVisitationContext context) throws RuntimeException { + var lhs = expr.condition().accept(this, context); return rexBuilder.makeIn( - lhs, expr.options().stream().map(e -> e.accept(this)).collect(Collectors.toList())); + lhs, + expr.options().stream().map(e -> e.accept(this, context)).collect(Collectors.toList())); } @Override - public RexNode visit(Expression.DateLiteral expr) throws RuntimeException { + public RexNode visit(Expression.DateLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( expr.value(), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.TimestampLiteral expr) throws RuntimeException { + public RexNode visit(Expression.TimestampLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(TimestampTZLiteral expr) throws RuntimeException { + public RexNode visit(TimestampTZLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(PrecisionTimestampLiteral expr) throws RuntimeException { + public RexNode visit(PrecisionTimestampLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value(), expr.precision()), typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(PrecisionTimestampTZLiteral expr) throws RuntimeException { + public RexNode visit(PrecisionTimestampTZLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeLiteral( getTimestampString(expr.value(), expr.precision()), typeConverter.toCalcite(typeFactory, expr.getType())); @@ -274,7 +297,8 @@ private TimestampString getTimestampString(long value, int precision) { } @Override - public RexNode visit(Expression.IntervalYearLiteral expr) throws RuntimeException { + public RexNode visit(Expression.IntervalYearLiteral expr, EmptyVisitationContext context) + throws RuntimeException { return rexBuilder.makeIntervalLiteral( new BigDecimal(expr.years() * 12 + expr.months()), YEAR_MONTH_INTERVAL); } @@ -282,7 +306,8 @@ public RexNode visit(Expression.IntervalYearLiteral expr) throws RuntimeExceptio private static final long MILLIS_IN_DAY = TimeUnit.DAYS.toMillis(1); @Override - public RexNode visit(Expression.IntervalDayLiteral expr) throws RuntimeException { + public RexNode visit(Expression.IntervalDayLiteral expr, EmptyVisitationContext context) + throws RuntimeException { long milliseconds = expr.precision() > 3 ? (expr.subseconds() / (int) Math.pow(10, expr.precision() - 3)) @@ -294,66 +319,81 @@ public RexNode visit(Expression.IntervalDayLiteral expr) throws RuntimeException } @Override - public RexNode visit(Expression.DecimalLiteral expr) throws RuntimeException { + public RexNode visit(Expression.DecimalLiteral expr, EmptyVisitationContext context) + throws RuntimeException { byte[] value = expr.value().toByteArray(); BigDecimal decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale(), 16); return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType())); } @Override - public RexNode visit(Expression.ListLiteral expr) throws RuntimeException { + public RexNode visit(Expression.ListLiteral expr, EmptyVisitationContext context) + throws RuntimeException { List args = - expr.values().stream().map(l -> l.accept(this)).collect(Collectors.toList()); + expr.values().stream().map(l -> l.accept(this, context)).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, args); } @Override - public RexNode visit(Expression.EmptyListLiteral expr) throws RuntimeException { + public RexNode visit(Expression.EmptyListLiteral expr, EmptyVisitationContext context) + throws RuntimeException { var calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); return rexBuilder.makeCall( calciteType, SqlStdOperatorTable.ARRAY_VALUE_CONSTRUCTOR, Collections.emptyList()); } @Override - public RexNode visit(Expression.MapLiteral expr) throws RuntimeException { + public RexNode visit(Expression.MapLiteral expr, EmptyVisitationContext context) + throws RuntimeException { var args = expr.values().entrySet().stream() - .flatMap(entry -> Stream.of(entry.getKey().accept(this), entry.getValue().accept(this))) + .flatMap( + entry -> + Stream.of( + entry.getKey().accept(this, context), + entry.getValue().accept(this, context))) .collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.MAP_VALUE_CONSTRUCTOR, args); } @Override - public RexNode visit(Expression.IfThen expr) throws RuntimeException { + public RexNode visit(Expression.IfThen expr, EmptyVisitationContext context) + throws RuntimeException { // In Calcite, the arguments to the CASE operator are given as: // ... ... Stream ifThenArgs = expr.ifClauses().stream() .flatMap( - clause -> Stream.of(clause.condition().accept(this), clause.then().accept(this))); - Stream elseArg = Stream.of(expr.elseClause().accept(this)); + clause -> + Stream.of( + clause.condition().accept(this, context), + clause.then().accept(this, context))); + Stream elseArg = Stream.of(expr.elseClause().accept(this, context)); List args = Stream.concat(ifThenArgs, elseArg).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } @Override - public RexNode visit(Switch expr) throws RuntimeException { - RexNode match = expr.match().accept(this); + public RexNode visit(Switch expr, EmptyVisitationContext context) throws RuntimeException { + RexNode match = expr.match().accept(this, context); Stream caseThenArgs = expr.switchClauses().stream() .flatMap( clause -> Stream.of( rexBuilder.makeCall( - SqlStdOperatorTable.EQUALS, match, clause.condition().accept(this)), - clause.then().accept(this))); - Stream defaultArg = Stream.of(expr.defaultClause().accept(this)); + SqlStdOperatorTable.EQUALS, + match, + clause.condition().accept(this, context)), + clause.then().accept(this, context))); + Stream defaultArg = Stream.of(expr.defaultClause().accept(this, context)); List args = Stream.concat(caseThenArgs, defaultArg).collect(Collectors.toList()); return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } @Override - public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeException { + public RexNode visit(Expression.ScalarFunctionInvocation expr, EmptyVisitationContext context) + throws RuntimeException { SqlOperator operator = scalarFunctionConverter .getSqlOperatorFromSubstraitFunc(expr.declaration().key(), expr.outputType()) @@ -366,7 +406,7 @@ public RexNode visit(Expression.ScalarFunctionInvocation expr) throws RuntimeExc var eArgs = scalarFunctionConverter.getExpressionArguments(expr); var args = IntStream.range(0, eArgs.size()) - .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this)) + .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this, context)) .collect(Collectors.toList()); RelDataType returnType = typeConverter.toCalcite(typeFactory, expr.outputType()); @@ -381,7 +421,8 @@ private String callConversionFailureMessage( } @Override - public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeException { + public RexNode visit(Expression.WindowFunctionInvocation expr, EmptyVisitationContext context) + throws RuntimeException { SqlOperator operator = windowFunctionConverter .getSqlOperatorFromSubstraitFunc(expr.declaration().key(), expr.outputType()) @@ -396,11 +437,11 @@ public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeExc List eArgs = expr.arguments(); List args = IntStream.range(0, eArgs.size()) - .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this)) + .mapToObj(i -> eArgs.get(i).accept(expr.declaration(), i, this, context)) .collect(Collectors.toList()); List partitionKeys = - expr.partitionBy().stream().map(e -> e.accept(this)).collect(Collectors.toList()); + expr.partitionBy().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); ImmutableList orderKeys = expr.sort().stream() @@ -415,7 +456,7 @@ public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeExc case CLUSTERED -> throw new IllegalArgumentException( "SORT_DIRECTION_CLUSTERED is not supported"); }; - return new RexFieldCollation(sf.expr().accept(this), direction); + return new RexFieldCollation(sf.expr().accept(this, context), direction); }) .collect(ImmutableList.toImmutableList()); @@ -461,10 +502,11 @@ public RexNode visit(Expression.WindowFunctionInvocation expr) throws RuntimeExc } @Override - public RexNode visit(Expression.InPredicate expr) throws RuntimeException { + public RexNode visit(Expression.InPredicate expr, EmptyVisitationContext context) + throws RuntimeException { List needles = - expr.needles().stream().map(e -> e.accept(this)).collect(Collectors.toList()); - RelNode rel = expr.haystack().accept(relNodeConverter); + expr.needles().stream().map(e -> e.accept(this, context)).collect(Collectors.toList()); + RelNode rel = expr.haystack().accept(relNodeConverter, context); return RexSubQuery.in(rel, ImmutableList.copyOf(needles)); } @@ -529,14 +571,18 @@ private String convert(FunctionArg a) { } @Override - public RexNode visit(Expression.Cast expr) throws RuntimeException { + public RexNode visit(Expression.Cast expr, EmptyVisitationContext context) + throws RuntimeException { var safeCast = expr.failureBehavior() == FailureBehavior.RETURN_NULL; return rexBuilder.makeAbstractCast( - typeConverter.toCalcite(typeFactory, expr.getType()), expr.input().accept(this), safeCast); + typeConverter.toCalcite(typeFactory, expr.getType()), + expr.input().accept(this, context), + safeCast); } @Override - public RexNode visit(FieldReference expr) throws RuntimeException { + public RexNode visit(FieldReference expr, EmptyVisitationContext context) + throws RuntimeException { if (expr.isSimpleRootReference()) { var segment = expr.segments().get(0); @@ -551,11 +597,11 @@ public RexNode visit(FieldReference expr) throws RuntimeException { return rexInputRef; } - return visitFallback(expr); + return visitFallback(expr, context); } @Override - public RexNode visitFallback(Expression expr) { + public RexNode visitFallback(Expression expr, EmptyVisitationContext context) { throw new UnsupportedOperationException( String.format( "Expression %s of type %s not handled by visitor type %s.", @@ -563,13 +609,15 @@ public RexNode visitFallback(Expression expr) { } @Override - public RexNode visitExpr(SimpleExtension.Function fnDef, int argIdx, Expression e) + public RexNode visitExpr( + SimpleExtension.Function fnDef, int argIdx, Expression e, EmptyVisitationContext context) throws RuntimeException { - return e.accept(this); + return e.accept(this, context); } @Override - public RexNode visitType(SimpleExtension.Function fnDef, int argIdx, Type t) + public RexNode visitType( + SimpleExtension.Function fnDef, int argIdx, Type t, EmptyVisitationContext context) throws RuntimeException { throw new UnsupportedOperationException( String.format( @@ -578,7 +626,8 @@ public RexNode visitType(SimpleExtension.Function fnDef, int argIdx, Type t) } @Override - public RexNode visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg e) + public RexNode visitEnumArg( + SimpleExtension.Function fnDef, int argIdx, EnumArg e, EmptyVisitationContext context) throws RuntimeException { return EnumConverter.toRex(rexBuilder, fnDef, argIdx, e) @@ -591,14 +640,15 @@ public RexNode visitEnumArg(SimpleExtension.Function fnDef, int argIdx, EnumArg } @Override - public RexNode visit(ScalarSubquery expr) throws RuntimeException { - RelNode inputRelnode = expr.input().accept(relNodeConverter); + public RexNode visit(ScalarSubquery expr, EmptyVisitationContext context) + throws RuntimeException { + RelNode inputRelnode = expr.input().accept(relNodeConverter, context); return RexSubQuery.scalar(inputRelnode); } @Override - public RexNode visit(SetPredicate expr) throws RuntimeException { - RelNode inputRelnode = expr.tuples().accept(relNodeConverter); + public RexNode visit(SetPredicate expr, EmptyVisitationContext context) throws RuntimeException { + RelNode inputRelnode = expr.tuples().accept(relNodeConverter, context); switch (expr.predicateOp()) { case PREDICATE_OP_EXISTS: return RexSubQuery.exists(inputRelnode); diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java index 5a347c85c..8202b41fb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteCallTest.java @@ -106,7 +106,7 @@ private void test( consumer.accept(func); if (bidirectional) { - RexNode convertedCall = expression.accept(expressionRexConverter); + RexNode convertedCall = expression.accept(expressionRexConverter, null); assertEquals(call, convertedCall); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java index 2d5afef47..9816c905c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java @@ -183,7 +183,8 @@ void tIntervalYearMonthWithPrecision() { assertEquals(intervalYearMonthExpr, intervalYearMonth.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = (RexLiteral) intervalYearMonthExpr.accept(expressionRexConverter); + RexLiteral convertedRex = + (RexLiteral) intervalYearMonthExpr.accept(expressionRexConverter, null); // Compare value only. Ignore the precision in SqlIntervalQualifier (which is used to parse // input string). @@ -231,7 +232,7 @@ void tIntervalDay() { assertEquals(intervalDayExpr, convertedExpr); // expression -> rex - RexLiteral convertedRex = (RexLiteral) intervalDayExpr.accept(expressionRexConverter); + RexLiteral convertedRex = (RexLiteral) intervalDayExpr.accept(expressionRexConverter, null); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. assertEquals( @@ -255,7 +256,7 @@ void tIntervalYear() { assertEquals(intervalYearExpr, intervalYear.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = (RexLiteral) intervalYearExpr.accept(expressionRexConverter); + RexLiteral convertedRex = (RexLiteral) intervalYearExpr.accept(expressionRexConverter, null); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. assertEquals( @@ -280,7 +281,7 @@ void tIntervalMonth() { assertEquals(intervalMonthExpr, intervalMonth.accept(rexExpressionConverter)); // expression -> rex - RexLiteral convertedRex = (RexLiteral) intervalMonthExpr.accept(expressionRexConverter); + RexLiteral convertedRex = (RexLiteral) intervalMonthExpr.accept(expressionRexConverter, null); // Compare value only. Ignore the precision in SqlIntervalQualifier in comparison. assertEquals( @@ -386,7 +387,7 @@ public void test(Expression expression, RexNode rex) { // bi-directional test : 1) rex -> substrait, substrait -> rex2. Compare rex == rex2 public void bitest(Expression expression, RexNode rex) { assertEquals(expression, rex.accept(rexExpressionConverter)); - RexNode convertedRex = expression.accept(expressionRexConverter); + RexNode convertedRex = expression.accept(expressionRexConverter, null); assertEquals(rex, convertedRex); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java index 1f54c445f..7a165496c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ExpressionConvertabilityTest.java @@ -71,7 +71,7 @@ public void inPredicate() throws IOException, SqlParseException { @Test public void singleOrList() { Expression singleOrList = b.singleOrList(b.fieldReference(commonTable, 0), b.i32(5), b.i32(10)); - RexNode rexNode = singleOrList.accept(converter); + RexNode rexNode = singleOrList.accept(converter, null); Expression substraitExpression = rexNode.accept( new RexExpressionConverter( @@ -93,7 +93,7 @@ public void switchExpression() { b.fieldReference(commonTable, 0), List.of(b.switchClause(b.i32(5), b.i32(1)), b.switchClause(b.i32(10), b.i32(2))), b.i32(3)); - RexNode rexNode = switchExpression.accept(converter); + RexNode rexNode = switchExpression.accept(converter, null); Expression expression = rexNode.accept( new RexExpressionConverter( @@ -131,7 +131,8 @@ void assertExpressionEquality(Expression expected, Expression actual) { // go the extra mile and convert both inputs to protobuf // helps verify that the protobuf conversion is not broken assertEquals( - expected.accept(expressionProtoConverter), actual.accept(expressionProtoConverter)); + expected.accept(expressionProtoConverter, null), + actual.accept(expressionProtoConverter, null)); } @Test @@ -147,7 +148,7 @@ void assertPrecisionTimestampLiteral(int precision) { .value(0) .precision(precision) .build() - .accept(converter); + .accept(converter, null); assertInstanceOf(RexLiteral.class, calciteExpr); } @@ -164,7 +165,7 @@ void assertPrecisionTimestampTZLiteral(int precision) { .value(0) .precision(precision) .build() - .accept(converter); + .accept(converter, null); assertInstanceOf(RexLiteral.class, calciteExpr); } @@ -231,6 +232,6 @@ void assertThrowsUnsupportedPrecisionPrecisionTimestampTZLiteral(int precision) } void assertThrowsExpressionLiteral(Expression.Literal expr) { - assertThrows(UnsupportedOperationException.class, () -> expr.accept(converter)); + assertThrows(UnsupportedOperationException.class, () -> expr.accept(converter, null)); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java index 070435674..49e60835d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/FunctionConversionTest.java @@ -59,7 +59,7 @@ public void subtractDateIDay() { ExpressionCreator.date(false, 10561), ExpressionCreator.intervalDay(false, 120, 0, 0, 6)); - var calciteExpr = expr.accept(expressionRexConverter); + var calciteExpr = expr.accept(expressionRexConverter, null); assertEquals( TypeConverter.DEFAULT.toCalcite(typeFactory, TypeCreator.REQUIRED.DATE), calciteExpr.getType()); @@ -79,7 +79,7 @@ public void extractTimestampTzScalarFunction() { Expression.TimestampTZLiteral.builder().value(0).build(), Expression.StrLiteral.builder().value("GMT").build()); - RexNode calciteExpr = reqTstzFn.accept(expressionRexConverter); + RexNode calciteExpr = reqTstzFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -100,7 +100,7 @@ public void extractPrecisionTimestampTzScalarFunction() { Expression.PrecisionTimestampTZLiteral.builder().value(0).precision(6).build(), Expression.StrLiteral.builder().value("GMT").build()); - RexNode calciteExpr = reqPtstzFn.accept(expressionRexConverter); + RexNode calciteExpr = reqPtstzFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -120,7 +120,7 @@ public void extractTimestampScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.TimestampLiteral.builder().value(0).build()); - RexNode calciteExpr = reqTsFn.accept(expressionRexConverter); + RexNode calciteExpr = reqTsFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -138,7 +138,7 @@ public void extractPrecisionTimestampScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.PrecisionTimestampLiteral.builder().value(0).precision(6).build()); - RexNode calciteExpr = reqPtsFn.accept(expressionRexConverter); + RexNode calciteExpr = reqPtsFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -156,7 +156,7 @@ public void extractDateScalarFunction() { EnumArg.builder().value("MONTH").build(), Expression.DateLiteral.builder().value(0).build()); - RexNode calciteExpr = reqDateFn.accept(expressionRexConverter); + RexNode calciteExpr = reqDateFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -174,7 +174,7 @@ public void extractTimeScalarFunction() { EnumArg.builder().value("MINUTE").build(), Expression.TimeLiteral.builder().value(0).build()); - RexNode calciteExpr = reqTimeFn.accept(expressionRexConverter); + RexNode calciteExpr = reqTimeFn.accept(expressionRexConverter, null); assertEquals(SqlKind.EXTRACT, calciteExpr.getKind()); assertInstanceOf(RexCall.class, calciteExpr); @@ -195,7 +195,8 @@ public void unsupportedExtractTimestampTzWithIndexing() { Expression.StrLiteral.builder().value("GMT").build()); assertThrows( - UnsupportedOperationException.class, () -> reqReqTstzFn.accept(expressionRexConverter)); + UnsupportedOperationException.class, + () -> reqReqTstzFn.accept(expressionRexConverter, null)); } @Test @@ -211,7 +212,8 @@ public void unsupportedExtractPrecisionTimestampTzWithIndexing() { Expression.StrLiteral.builder().value("GMT").build()); assertThrows( - UnsupportedOperationException.class, () -> reqReqPtstzFn.accept(expressionRexConverter)); + UnsupportedOperationException.class, + () -> reqReqPtstzFn.accept(expressionRexConverter, null)); } @Test @@ -226,7 +228,7 @@ public void unsupportedExtractTimestampWithIndexing() { Expression.TimestampLiteral.builder().value(0).build()); assertThrows( - UnsupportedOperationException.class, () -> reqReqTsFn.accept(expressionRexConverter)); + UnsupportedOperationException.class, () -> reqReqTsFn.accept(expressionRexConverter, null)); } @Test @@ -241,7 +243,8 @@ public void unsupportedExtractPrecisionTimestampWithIndexing() { Expression.PrecisionTimestampLiteral.builder().value(0).precision(6).build()); assertThrows( - UnsupportedOperationException.class, () -> reqReqPtsFn.accept(expressionRexConverter)); + UnsupportedOperationException.class, + () -> reqReqPtsFn.accept(expressionRexConverter, null)); } @Test @@ -256,7 +259,8 @@ public void unsupportedExtractDateWithIndexing() { Expression.DateLiteral.builder().value(0).build()); assertThrows( - UnsupportedOperationException.class, () -> reqReqDateFn.accept(expressionRexConverter)); + UnsupportedOperationException.class, + () -> reqReqDateFn.accept(expressionRexConverter, null)); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index 366db0717..958f4ad0d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -11,6 +11,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelCopyOnWriteVisitor; import io.substrait.relation.Set; +import io.substrait.util.EmptyVisitationContext; import java.io.IOException; import java.util.Arrays; import java.util.Optional; @@ -74,9 +75,10 @@ public void crossJoin() throws IOException, SqlParseException { int[] counter = new int[1]; var crossJoinCountingVisitor = new RelCopyOnWriteVisitor() { - public Optional visit(Cross cross) throws RuntimeException { + public Optional visit(Cross cross, EmptyVisitationContext context) + throws RuntimeException { counter[0]++; - return super.visit(cross); + return super.visit(cross, context); } }; var featureBoard = ImmutableFeatureBoard.builder().build(); @@ -92,7 +94,7 @@ public Optional visit(Cross cross) throws RuntimeException { "orders" o """, new SqlToSubstrait(featureBoard)); - plan1.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor)); + plan1.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null)); assertEquals(1, counter[0]); Plan plan2 = @@ -106,7 +108,7 @@ public Optional visit(Cross cross) throws RuntimeException { "orders" o """, new SqlToSubstrait(featureBoard)); - plan2.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor)); + plan2.getRoots().forEach(t -> t.getInput().accept(crossJoinCountingVisitor, null)); assertEquals(2, counter[0]); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index fe3eac106..f17d3485a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -14,6 +14,7 @@ import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; import io.substrait.relation.RelCopyOnWriteVisitor; +import io.substrait.util.EmptyVisitationContext; import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -164,7 +165,7 @@ public void replaceCountDistinctsInUnion() throws IOException, SqlParseException private static class HasTableReference { public boolean hasTableReference(Plan plan, String name) { HasTableReferenceVisitor visitor = new HasTableReferenceVisitor(Arrays.asList(name)); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor)); + plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); return (visitor.hasTableReference()); } @@ -181,9 +182,9 @@ public boolean hasTableReference() { } @Override - public Optional visit(NamedScan namedScan) { + public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { this.hasTableReference |= namedScan.getNames().equals(tableName); - return super.visit(namedScan); + return super.visit(namedScan, context); } } } @@ -198,7 +199,7 @@ private static class CountCountDistinct { public int getCountDistincts(Plan plan) { CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor)); + plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); return visitor.getCountDistincts(); } @@ -210,7 +211,7 @@ public int getCountDistincts() { } @Override - public Optional visit(Aggregate aggregate) { + public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { countDistincts += aggregate.getMeasures().stream() .filter( @@ -220,7 +221,7 @@ public Optional visit(Aggregate aggregate) { .invocation() .equals(Expression.AggregationInvocation.DISTINCT)) .count(); - return super.visit(aggregate); + return super.visit(aggregate, context); } } } @@ -229,7 +230,7 @@ private static class CountApproxCountDistinct { public int getApproxCountDistincts(Plan plan) { CountCountDistinctVisitor visitor = new CountCountDistinctVisitor(); - plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor)); + plan.getRoots().stream().forEach(r -> r.getInput().accept(visitor, null)); return visitor.getApproxCountDistincts(); } @@ -241,13 +242,13 @@ public int getApproxCountDistincts() { } @Override - public Optional visit(Aggregate aggregate) { + public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { aproxCountDistincts += aggregate.getMeasures().stream() .filter( m -> m.getFunction().declaration().getAnchor().equals(APPROX_COUNT_DISTINCT)) .count(); - return super.visit(aggregate); + return super.visit(aggregate, context); } } } @@ -260,11 +261,12 @@ public ReplaceCountDistinctWithApprox() { } public Optional modify(Plan plan) { - return CopyOnWriteUtils.transformList( + return CopyOnWriteUtils.transformList( plan.getRoots(), - t -> + null, + (t, c) -> t.getInput() - .accept(visitor) + .accept(visitor, c) .map(u -> Plan.Root.builder().from(t).input(u).build())) .map(t -> Plan.builder().from(plan).roots(t).build()); } @@ -281,10 +283,12 @@ public ReplaceCountDistinctWithApproxVisitor( } @Override - public Optional visit(Aggregate aggregate) { - return CopyOnWriteUtils.transformList( + public Optional visit(Aggregate aggregate, EmptyVisitationContext context) { + return CopyOnWriteUtils + .transformList( aggregate.getMeasures(), - m -> { + context, + (m, c) -> { if (m.getFunction().invocation().equals(Expression.AggregationInvocation.DISTINCT) && m.getFunction().declaration().getAnchor().equals(COUNT)) { return Optional.of( diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index c71f1539f..17f246896 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -17,6 +17,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; +import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -60,7 +61,8 @@ void extensionMultiRelDetailTest() { void roundtrip(Rel pojo1) { // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.Rel proto = pojo1.accept(new RelProtoConverter(new ExtensionCollector())); + io.substrait.proto.Rel proto = + pojo1.accept(new RelProtoConverter(new ExtensionCollector()), null); // Substrait Proto -> Substrait POJO 2 var pojo2 = (new CustomProtoRelConverter(new ExtensionCollector())).from(proto); @@ -68,7 +70,7 @@ void roundtrip(Rel pojo1) { // Substrait POJO 2 -> Calcite var calcite = - pojo2.accept(new CustomSubstraitRelNodeConverter(extensions, typeFactory, builder)); + pojo2.accept(new CustomSubstraitRelNodeConverter(extensions, typeFactory, builder), null); // Calcite -> Substrait POJO 3 var pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); @@ -192,10 +194,11 @@ public CustomSubstraitRelNodeConverter( super(extensions, typeFactory, relBuilder); } - public RelNode visit(ExtensionLeaf extensionLeaf) { + @Override + public RelNode visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) { if (extensionLeaf.getDetail() instanceof ColumnAppendDetail) { ColumnAppendDetail cad = (ColumnAppendDetail) extensionLeaf.getDetail(); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter); + RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); RelOptCluster cluster = relBuilder.getCluster(); RelTraitSet traits = cluster.traitSet(); return new ColumnAppenderRel( @@ -205,11 +208,12 @@ public RelNode visit(ExtensionLeaf extensionLeaf) { } @Override - public RelNode visit(ExtensionSingle extensionSingle) throws RuntimeException { + public RelNode visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + throws RuntimeException { if (extensionSingle.getDetail() instanceof ColumnAppendDetail) { ColumnAppendDetail cad = (ColumnAppendDetail) extensionSingle.getDetail(); - RelNode input = extensionSingle.getInput().accept(this); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter); + RelNode input = extensionSingle.getInput().accept(this, context); + RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); return new ColumnAppenderRel( input.getCluster(), input.getTraitSet(), literal, List.of(input)); } @@ -217,14 +221,15 @@ public RelNode visit(ExtensionSingle extensionSingle) throws RuntimeException { } @Override - public RelNode visit(ExtensionMulti extensionMulti) throws RuntimeException { + public RelNode visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + throws RuntimeException { if (extensionMulti.getDetail() instanceof ColumnAppendDetail) { ColumnAppendDetail cad = (ColumnAppendDetail) extensionMulti.getDetail(); List inputs = extensionMulti.getInputs().stream() - .map(input -> input.accept(this)) + .map(input -> input.accept(this, context)) .collect(Collectors.toList()); - RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter); + RexLiteral literal = (RexLiteral) cad.literal.accept(this.expressionRexConverter, context); return new ColumnAppenderRel( inputs.get(0).getCluster(), inputs.get(0).getTraitSet(), literal, inputs); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java index 090fc98a8..11d7d04e7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java @@ -51,7 +51,7 @@ public void switchExpression() { b.fieldReference(commonTable, 0), List.of(b.switchClause(b.i32(0), b.fieldReference(commonTable, 3))), b.bool(false)); - var calciteExpr = expr.accept(converter); + var calciteExpr = expr.accept(converter, null); assertTypeMatch(calciteExpr.getType(), N.BOOLEAN); } @@ -174,7 +174,7 @@ public void useSubstraitReturnTypeDuringScalarFunctionConversion() { b.i32(7), b.i32(42)); - RexNode calciteExpr = expr.accept(expressionRexConverter); + RexNode calciteExpr = expr.accept(expressionRexConverter, null); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.FP32), calciteExpr.getType()); } @@ -194,7 +194,7 @@ public void useSubstraitReturnTypeDuringWindowFunctionConversion() { WindowBound.UNBOUNDED, b.i32(42)); - RexNode calciteExpr = expr.accept(expressionRexConverter); + RexNode calciteExpr = expr.accept(expressionRexConverter, null); assertEquals(TypeConverter.DEFAULT.toCalcite(typeFactory, R.STRING), calciteExpr.getType()); } diff --git a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala index 2d4d3f833..5377f4257 100644 --- a/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala +++ b/spark/src/main/scala/io/substrait/debug/ExpressionToString.scala @@ -24,72 +24,77 @@ import io.substrait.expression.{Expression, FieldReference} import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, I64Literal, StrLiteral} import io.substrait.function.ToTypeString import io.substrait.util.DecimalUtil +import io.substrait.util.EmptyVisitationContext import scala.collection.JavaConverters.asScalaBufferConverter class ExpressionToString extends DefaultExpressionVisitor[String] { - override def visit(expr: DecimalLiteral): String = { + override def visit(expr: DecimalLiteral, context: EmptyVisitationContext): String = { val value = expr.value.toByteArray val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) decimal.toString } - override def visit(expr: StrLiteral): String = { + override def visit(expr: StrLiteral, context: EmptyVisitationContext): String = { expr.value() } - override def visit(expr: I32Literal): String = { + override def visit(expr: I32Literal, context: EmptyVisitationContext): String = { expr.value().toString } - override def visit(expr: I64Literal): String = { + override def visit(expr: I64Literal, context: EmptyVisitationContext): String = { expr.value().toString } - override def visit(expr: DateLiteral): String = { + override def visit(expr: DateLiteral, context: EmptyVisitationContext): String = { DateTimeUtils.toJavaDate(expr.value()).toString } - override def visit(expr: FieldReference): String = { + override def visit(expr: FieldReference, context: EmptyVisitationContext): String = { withFieldReference(expr)(i => "$" + i.toString) } - override def visit(expr: Expression.SingleOrList): String = { + override def visit(expr: Expression.SingleOrList, context: EmptyVisitationContext): String = { expr.toString } - override def visit(expr: Expression.ScalarFunctionInvocation): String = { + override def visit( + expr: Expression.ScalarFunctionInvocation, + context: EmptyVisitationContext): String = { val args = expr .arguments() .asScala .zipWithIndex .map { case (arg, i) => - arg.accept(expr.declaration(), i, this) + arg.accept(expr.declaration(), i, this, context) } .mkString(",") s"${expr.declaration().key()}[${expr.outputType().accept(ToTypeString.INSTANCE)}]($args)" } - override def visit(expr: Expression.UserDefinedLiteral): String = { + override def visit( + expr: Expression.UserDefinedLiteral, + context: EmptyVisitationContext): String = { expr.toString } - override def visit(expr: Expression.EmptyMapLiteral): String = { + override def visit(expr: Expression.EmptyMapLiteral, context: EmptyVisitationContext): String = { expr.toString } - override def visit(expr: Expression.Cast): String = { + override def visit(expr: Expression.Cast, context: EmptyVisitationContext): String = { expr.getType.toString } - override def visit(expr: Expression.InPredicate): String = { + override def visit(expr: Expression.InPredicate, context: EmptyVisitationContext): String = { expr.toString } - override def visit(expr: Expression.ScalarSubquery): String = { + override def visit(expr: Expression.ScalarSubquery, context: EmptyVisitationContext): String = { expr.toString } } diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 2703977df..65109f4b4 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -19,6 +19,7 @@ package io.substrait.debug import io.substrait.spark.DefaultRelVisitor import io.substrait.relation._ +import io.substrait.util.EmptyVisitationContext import scala.collection.mutable @@ -41,10 +42,10 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { } def apply(rel: Rel, maxFields: Int): String = { - rel.accept(this) + rel.accept(this, null) } - override def visit(fetch: Fetch): String = { + override def visit(fetch: Fetch, context: EmptyVisitationContext): String = { withBuilder(fetch, 7)( builder => { builder.append("offset=").append(fetch.getOffset) @@ -55,14 +56,14 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) }) } - override def visit(sort: Sort): String = { + override def visit(sort: Sort, context: EmptyVisitationContext): String = { withBuilder(sort, 5)( builder => { builder.append("sortFields=").append(sort.getSortFields) }) } - override def visit(join: Join): String = { + override def visit(join: Join, context: EmptyVisitationContext): String = { withBuilder(join, 5)( builder => { join.getCondition.ifPresent( @@ -80,10 +81,10 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(filter: Filter): String = { + override def visit(filter: Filter, context: EmptyVisitationContext): String = { withBuilder(filter, 7)( builder => { - builder.append(filter.getCondition.accept(expressionStringConverter)) + builder.append(filter.getCondition.accept(expressionStringConverter, context)) }) } @@ -101,7 +102,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(namedScan: NamedScan): String = { + override def visit(namedScan: NamedScan, context: EmptyVisitationContext): String = { withBuilder(namedScan, 10)( builder => { fillReadRel(namedScan, builder) @@ -116,7 +117,9 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(virtualTableScan: VirtualTableScan): String = { + override def visit( + virtualTableScan: VirtualTableScan, + context: EmptyVisitationContext): String = { withBuilder(virtualTableScan, 10)( builder => { fillReadRel(virtualTableScan, builder) @@ -131,14 +134,14 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(emptyScan: EmptyScan): String = { + override def visit(emptyScan: EmptyScan, context: EmptyVisitationContext): String = { withBuilder(emptyScan, 10)( builder => { fillReadRel(emptyScan, builder) }) } - override def visit(project: Project): String = { + override def visit(project: Project, context: EmptyVisitationContext): String = { withBuilder(project, 8)( builder => { builder @@ -147,7 +150,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(expand: Expand): String = { + override def visit(expand: Expand, context: EmptyVisitationContext): String = { withBuilder(expand, 8)( builder => { builder @@ -156,7 +159,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(aggregate: Aggregate): String = { + override def visit(aggregate: Aggregate, context: EmptyVisitationContext): String = { withBuilder(aggregate, 10)( builder => { builder @@ -168,7 +171,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(window: ConsistentPartitionWindow): String = { + override def visit(window: ConsistentPartitionWindow, context: EmptyVisitationContext): String = { withBuilder(window, 10)( builder => { builder @@ -181,7 +184,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(set: Set): String = { + override def visit(set: Set, context: EmptyVisitationContext): String = { withBuilder(set, 8)( builder => { builder @@ -190,7 +193,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(cross: Cross): String = { + override def visit(cross: Cross, context: EmptyVisitationContext): String = { withBuilder(cross, 10)( builder => { builder @@ -201,7 +204,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } - override def visit(localFiles: LocalFiles): String = { + override def visit(localFiles: LocalFiles, context: EmptyVisitationContext): String = { withBuilder(localFiles, 10)( builder => { builder diff --git a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala index d0d2e0d00..5f7137b14 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultExpressionVisitor.scala @@ -19,21 +19,30 @@ package io.substrait.spark import io.substrait.`type`.Type import io.substrait.expression._ import io.substrait.extension.SimpleExtension +import io.substrait.util.EmptyVisitationContext class DefaultExpressionVisitor[T] - extends AbstractExpressionVisitor[T, RuntimeException] - with FunctionArg.FuncArgVisitor[T, RuntimeException] { + extends AbstractExpressionVisitor[T, EmptyVisitationContext, RuntimeException] + with FunctionArg.FuncArgVisitor[T, EmptyVisitationContext, RuntimeException] { - override def visitFallback(expr: Expression): T = + override def visitFallback(expr: Expression, context: EmptyVisitationContext): T = throw new UnsupportedOperationException( s"Expression type ${expr.getClass.getCanonicalName} " + s"not handled by visitor type ${getClass.getCanonicalName}.") - override def visitType(fnDef: SimpleExtension.Function, argIdx: Int, t: Type): T = + override def visitType( + fnDef: SimpleExtension.Function, + argIdx: Int, + t: Type, + context: EmptyVisitationContext): T = throw new UnsupportedOperationException( s"FunctionArg $t not handled by visitor type ${getClass.getCanonicalName}.") - override def visitEnumArg(fnDef: SimpleExtension.Function, argIdx: Int, e: EnumArg): T = + override def visitEnumArg( + fnDef: SimpleExtension.Function, + argIdx: Int, + e: EnumArg, + context: EmptyVisitationContext): T = throw new UnsupportedOperationException( s"EnumArg(value=${e.value()}) not handled by visitor type ${getClass.getCanonicalName}.") @@ -45,14 +54,20 @@ class DefaultExpressionVisitor[T] case _ => throw new IllegalArgumentException(s"Unhandled type: $segment") } } else { - visitFallback(fieldReference) + visitFallback(fieldReference, null) } } - override def visitExpr(fnDef: SimpleExtension.Function, argIdx: Int, e: Expression): T = - e.accept(this) + override def visitExpr( + fnDef: SimpleExtension.Function, + argIdx: Int, + e: Expression, + context: EmptyVisitationContext): T = + e.accept(this, context) - override def visit(userDefinedLiteral: Expression.UserDefinedLiteral): T = { - visitFallback(userDefinedLiteral) + override def visit( + userDefinedLiteral: Expression.UserDefinedLiteral, + context: EmptyVisitationContext): T = { + visitFallback(userDefinedLiteral, context) } } diff --git a/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala index 7f1e181b5..d97ca5ac4 100644 --- a/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala +++ b/spark/src/main/scala/io/substrait/spark/DefaultRelVisitor.scala @@ -18,10 +18,11 @@ package io.substrait.spark import io.substrait.relation import io.substrait.relation.AbstractRelVisitor +import io.substrait.util.EmptyVisitationContext -class DefaultRelVisitor[T] extends AbstractRelVisitor[T, RuntimeException] { +class DefaultRelVisitor[T] extends AbstractRelVisitor[T, EmptyVisitationContext, RuntimeException] { - override def visitFallback(rel: relation.Rel): T = + override def visitFallback(rel: relation.Rel, context: EmptyVisitationContext): T = throw new UnsupportedOperationException( s"Type ${rel.getClass.getCanonicalName}" + s" not handled by visitor type ${getClass.getCanonicalName}.") diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index fa42c2463..35d5fb0a7 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -31,6 +31,7 @@ import io.substrait.{expression => exp} import io.substrait.expression.{EnumArg, Expression => SExpression} import io.substrait.extension.SimpleExtension import io.substrait.util.DecimalUtil +import io.substrait.util.EmptyVisitationContext import io.substrait.utils.Util import scala.collection.JavaConverters.{asScalaBufferConverter, mapAsScalaMapConverter} @@ -41,7 +42,7 @@ class ToSparkExpression( extends DefaultExpressionVisitor[Expression] with HasOutputStack[Seq[NamedExpression]] { - override def visit(expr: SExpression.BoolLiteral): Expression = { + override def visit(expr: SExpression.BoolLiteral, context: EmptyVisitationContext): Expression = { if (expr.value()) { Literal.TrueLiteral } else { @@ -49,69 +50,81 @@ class ToSparkExpression( } } - override def visit(expr: SExpression.I8Literal): Expression = { + override def visit(expr: SExpression.I8Literal, context: EmptyVisitationContext): Expression = { Literal(expr.value().asInstanceOf[Byte], ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.I16Literal): Expression = { + override def visit(expr: SExpression.I16Literal, context: EmptyVisitationContext): Expression = { Literal(expr.value().asInstanceOf[Short], ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.I32Literal): Expression = { + override def visit(expr: SExpression.I32Literal, context: EmptyVisitationContext): Expression = { Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.I64Literal): Expression = { + override def visit(expr: SExpression.I64Literal, context: EmptyVisitationContext): Expression = { Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.FP32Literal): Literal = { + override def visit(expr: SExpression.FP32Literal, context: EmptyVisitationContext): Literal = { Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.FP64Literal): Expression = { + override def visit(expr: SExpression.FP64Literal, context: EmptyVisitationContext): Expression = { Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.StrLiteral): Expression = { + override def visit(expr: SExpression.StrLiteral, context: EmptyVisitationContext): Expression = { Literal(UTF8String.fromString(expr.value()), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.FixedCharLiteral): Expression = { + override def visit( + expr: SExpression.FixedCharLiteral, + context: EmptyVisitationContext): Expression = { Literal(UTF8String.fromString(expr.value()), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.VarCharLiteral): Expression = { + override def visit( + expr: SExpression.VarCharLiteral, + context: EmptyVisitationContext): Expression = { Literal(UTF8String.fromString(expr.value()), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.BinaryLiteral): Literal = { + override def visit(expr: SExpression.BinaryLiteral, context: EmptyVisitationContext): Literal = { Literal(expr.value().toByteArray, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.DecimalLiteral): Expression = { + override def visit( + expr: SExpression.DecimalLiteral, + context: EmptyVisitationContext): Expression = { val value = expr.value.toByteArray val decimal = DecimalUtil.getBigDecimalFromBytes(value, expr.scale, 16) Literal(Decimal(decimal), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.DateLiteral): Expression = { + override def visit(expr: SExpression.DateLiteral, context: EmptyVisitationContext): Expression = { Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.PrecisionTimestampLiteral): Literal = { + override def visit( + expr: SExpression.PrecisionTimestampLiteral, + context: EmptyVisitationContext): Literal = { // Spark timestamps are stored as a microseconds Long Util.assertMicroseconds(expr.precision()) Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.PrecisionTimestampTZLiteral): Literal = { + override def visit( + expr: SExpression.PrecisionTimestampTZLiteral, + context: EmptyVisitationContext): Literal = { // Spark timestamps are stored as a microseconds Long Util.assertMicroseconds(expr.precision()) Literal(expr.value(), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.IntervalDayLiteral): Literal = { + override def visit( + expr: SExpression.IntervalDayLiteral, + context: EmptyVisitationContext): Literal = { Util.assertMicroseconds(expr.precision()) // Spark uses a single microseconds Long as the "physical" type for DayTimeInterval val micros = @@ -120,48 +133,56 @@ class ToSparkExpression( Literal(micros, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.IntervalYearLiteral): Literal = { + override def visit( + expr: SExpression.IntervalYearLiteral, + context: EmptyVisitationContext): Literal = { // Spark uses a single months Int as the "physical" type for YearMonthInterval val months = expr.years() * 12 + expr.months() Literal(months, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.ListLiteral): Literal = { - val array = expr.values().asScala.map(value => value.accept(this).asInstanceOf[Literal].value) + override def visit(expr: SExpression.ListLiteral, context: EmptyVisitationContext): Literal = { + val array = + expr.values().asScala.map(value => value.accept(this, context).asInstanceOf[Literal].value) Literal.create(array, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.EmptyListLiteral): Expression = { + override def visit( + expr: SExpression.EmptyListLiteral, + context: EmptyVisitationContext): Expression = { Literal.default(ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.MapLiteral): Literal = { + override def visit(expr: SExpression.MapLiteral, context: EmptyVisitationContext): Literal = { val map = expr.values().asScala.map { case (key, value) => ( - key.accept(this).asInstanceOf[Literal].value, - value.accept(this).asInstanceOf[Literal].value + key.accept(this, context).asInstanceOf[Literal].value, + value.accept(this, context).asInstanceOf[Literal].value ) } Literal.create(map, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.EmptyMapLiteral): Literal = { + override def visit( + expr: SExpression.EmptyMapLiteral, + context: EmptyVisitationContext): Literal = { Literal.default(ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.StructLiteral): Literal = { + override def visit(expr: SExpression.StructLiteral, context: EmptyVisitationContext): Literal = { Literal.create( - Row.fromSeq(expr.fields.asScala.map(field => field.accept(this).asInstanceOf[Literal].value)), + Row.fromSeq( + expr.fields.asScala.map(field => field.accept(this, context).asInstanceOf[Literal].value)), ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.NullLiteral): Expression = { + override def visit(expr: SExpression.NullLiteral, context: EmptyVisitationContext): Expression = { Literal(null, ToSparkType.convert(expr.getType)) } - override def visit(expr: SExpression.Cast): Expression = { - val childExp = expr.input().accept(this) + override def visit(expr: SExpression.Cast, context: EmptyVisitationContext): Expression = { + val childExp = expr.input().accept(this, context) val tt = ToSparkType.convert(expr.getType) val tz = if (Cast.needsTimeZone(childExp.dataType, tt)) @@ -171,51 +192,55 @@ class ToSparkExpression( Cast(childExp, tt, tz) } - override def visit(expr: exp.FieldReference): Expression = { + override def visit(expr: exp.FieldReference, context: EmptyVisitationContext): Expression = { withFieldReference(expr)(i => currentOutput(i).clone()) } - override def visit(expr: SExpression.IfThen): Expression = { + override def visit(expr: SExpression.IfThen, context: EmptyVisitationContext): Expression = { val branches = expr .ifClauses() .asScala .map( ifClause => { - val predicate = ifClause.condition().accept(this) - val elseValue = ifClause.`then`().accept(this) + val predicate = ifClause.condition().accept(this, context) + val elseValue = ifClause.`then`().accept(this, context) (predicate, elseValue) }) - val default = expr.elseClause().accept(this) match { + val default = expr.elseClause().accept(this, context) match { case l: Literal if l.nullable => None case other => Some(other) } CaseWhen(branches, default) } - override def visit(expr: SExpression.ScalarSubquery): Expression = { + override def visit( + expr: SExpression.ScalarSubquery, + context: EmptyVisitationContext): Expression = { val rel = expr.input() val dataType = ToSparkType.convert(expr.getType) toLogicalPlan .map( relConverter => { - val plan = rel.accept(relConverter) + val plan = rel.accept(relConverter, context) require(plan.resolved) val result = ScalarSubquery(plan) SparkTypeUtil.sameType(result.dataType, dataType) result }) - .getOrElse(visitFallback(expr)) + .getOrElse(visitFallback(expr, context)) } - override def visit(expr: SExpression.SingleOrList): Expression = { - val value = expr.condition().accept(this) - val list = expr.options().asScala.map(e => e.accept(this)) + override def visit( + expr: SExpression.SingleOrList, + context: EmptyVisitationContext): Expression = { + val value = expr.condition().accept(this, context) + val list = expr.options().asScala.map(e => e.accept(this, context)) In(value, list) } - override def visit(expr: SExpression.InPredicate): Expression = { - val needles = expr.needles().asScala.map(e => e.accept(this)) - val haystack = expr.haystack().accept(toLogicalPlan.get) + override def visit(expr: SExpression.InPredicate, context: EmptyVisitationContext): Expression = { + val needles = expr.needles().asScala.map(e => e.accept(this, context)) + val haystack = expr.haystack().accept(toLogicalPlan.get, context) new InSubquery(needles, ListQuery(haystack, childOutputs = haystack.output)) { override def nullable: Boolean = expr.getType.nullable() } @@ -224,15 +249,18 @@ class ToSparkExpression( override def visitEnumArg( fnDef: SimpleExtension.Function, argIdx: Int, - e: EnumArg): Expression = { + e: EnumArg, + context: EmptyVisitationContext): Expression = { Enum(e.value.orElse("")) } - override def visit(expr: SExpression.ScalarFunctionInvocation): Expression = { + override def visit( + expr: SExpression.ScalarFunctionInvocation, + context: EmptyVisitationContext): Expression = { val eArgs = expr.arguments().asScala val args = eArgs.zipWithIndex.map { case (arg, i) => - arg.accept(expr.declaration(), i, this) + arg.accept(expr.declaration(), i, this, context) }.toList scalarFunctionConverter diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index afb8c16ba..63ba4c0e7 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -44,6 +44,7 @@ import io.substrait.relation.Expand.{ConsistentField, SwitchingField} import io.substrait.relation.LocalFiles import io.substrait.relation.Set.SetOp import io.substrait.relation.files.FileFormat +import io.substrait.util.EmptyVisitationContext import org.apache.hadoop.fs.Path import scala.collection.JavaConverters.asScalaBufferConverter @@ -64,7 +65,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) val function = measure.getFunction var arguments = function.arguments().asScala.zipWithIndex.map { case (arg, i) => - arg.accept(function.declaration(), i, expressionConverter) + arg.accept(function.declaration(), i, expressionConverter, null) } if (function.declaration.name == "count" && function.arguments.size == 0) { // HACK - count() needs to be rewritten as count(1) @@ -91,7 +92,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) }) val filter = Option(measure.getPreMeasureFilter.orElse(null)) - .map(_.accept(expressionConverter)) + .map(_.accept(expressionConverter, null)) AggregateExpression( aggregateFunction, @@ -106,15 +107,17 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) case other => Alias(other, toPrettySQL(other))() } - override def visit(aggregate: relation.Aggregate): LogicalPlan = { + override def visit( + aggregate: relation.Aggregate, + context: EmptyVisitationContext): LogicalPlan = { require(aggregate.getGroupings.size() == 1) - val child = aggregate.getInput.accept(this) + val child = aggregate.getInput.accept(this, context) withChild(child) { val groupBy = aggregate.getGroupings .get(0) .getExpressions .asScala - .map(expr => expr.accept(expressionConverter)) + .map(expr => expr.accept(expressionConverter, context)) val outputs = groupBy.map(toNamedExpression) val aggregateExpressions = @@ -123,18 +126,20 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(window: relation.ConsistentPartitionWindow): LogicalPlan = { - val child = window.getInput.accept(this) + override def visit( + window: relation.ConsistentPartitionWindow, + context: EmptyVisitationContext): LogicalPlan = { + val child = window.getInput.accept(this, context) withChild(child) { val partitions = window.getPartitionExpressions.asScala - .map(expr => expr.accept(expressionConverter)) + .map(expr => expr.accept(expressionConverter, context)) val sortOrders = window.getSorts.asScala.map(toSortOrder) val windowExpressions = window.getWindowFunctions.asScala .map( func => { val arguments = func.arguments().asScala.zipWithIndex.map { case (arg, i) => - arg.accept(func.declaration(), i, expressionConverter) + arg.accept(func.declaration(), i, expressionConverter, context) } val windowFunction = SparkExtension.toWindowFunction .getSparkExpressionFromSubstraitFunc(func.declaration.key, arguments) @@ -172,12 +177,12 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(join: relation.Join): LogicalPlan = { - val left = join.getLeft.accept(this) - val right = join.getRight.accept(this) + override def visit(join: relation.Join, context: EmptyVisitationContext): LogicalPlan = { + val left = join.getLeft.accept(this, context) + val right = join.getRight.accept(this, context) withChild(left, right) { val condition = Option(join.getCondition.orElse(null)) - .map(_.accept(expressionConverter)) + .map(_.accept(expressionConverter, context)) val joinType = join.getJoinType match { case relation.Join.JoinType.INNER => Inner @@ -197,9 +202,9 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(join: relation.Cross): LogicalPlan = { - val left = join.getLeft.accept(this) - val right = join.getRight.accept(this) + override def visit(join: relation.Cross, context: EmptyVisitationContext): LogicalPlan = { + val left = join.getLeft.accept(this, context) + val right = join.getRight.accept(this, context) withChild(left, right) { // TODO: Support different join types here when join types are added to cross rel for BNLJ // Currently, this will change both cross and inner join types to inner join @@ -208,7 +213,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } private def toSortOrder(sortField: SExpression.SortField): SortOrder = { - val expression = sortField.expr().accept(expressionConverter) + val expression = sortField.expr().accept(expressionConverter, null) val (direction, nullOrdering) = sortField.direction() match { case SExpression.SortDirection.ASC_NULLS_FIRST => (Ascending, NullsFirst) case SExpression.SortDirection.DESC_NULLS_FIRST => (Descending, NullsFirst) @@ -220,8 +225,8 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) SortOrder(expression, direction, nullOrdering, Seq.empty) } - override def visit(fetch: relation.Fetch): LogicalPlan = { - val child = fetch.getInput.accept(this) + override def visit(fetch: relation.Fetch, context: EmptyVisitationContext): LogicalPlan = { + val child = fetch.getInput.accept(this, context) val limit = fetch.getCount.orElse(-1).intValue() // -1 means unassigned here val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) @@ -239,8 +244,8 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(sort: relation.Sort): LogicalPlan = { - val child = sort.getInput.accept(this) + override def visit(sort: relation.Sort, context: EmptyVisitationContext): LogicalPlan = { + val child = sort.getInput.accept(this, context) withChild(child) { val sortOrders = sort.getSortFields.asScala.map(toSortOrder) Sort(sortOrders, global = true, child) @@ -267,8 +272,8 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(project: relation.Project): LogicalPlan = { - val child = project.getInput.accept(this) + override def visit(project: relation.Project, context: EmptyVisitationContext): LogicalPlan = { + val child = project.getInput.accept(this, context) val (output, createProject) = child match { case a: Aggregate => (a.aggregateExpressions, false) case other => (other.output, true) @@ -278,7 +283,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) withOutput(output) { val projectExprs = project.getExpressions.asScala - .map(expr => expr.accept(expressionConverter)) + .map(expr => expr.accept(expressionConverter, context)) val projectList = if (names.size == projectExprs.size) { projectExprs.zip(names).map { case (expr, name) => Alias(expr, name)() } } else { @@ -293,8 +298,8 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(expand: relation.Expand): LogicalPlan = { - val child = expand.getInput.accept(this) + override def visit(expand: relation.Expand, context: EmptyVisitationContext): LogicalPlan = { + val child = expand.getInput.accept(this, context) val names = fieldNames(expand).getOrElse( expand.getFields.asScala.zipWithIndex.map { case (_, i) => s"col$i" } ) @@ -304,7 +309,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) .map { case sf: SwitchingField => sf.getDuplicates.asScala - .map(expr => expr.accept(expressionConverter)) + .map(expr => expr.accept(expressionConverter, context)) .map(toNamedExpression) case _: ConsistentField => throw new UnsupportedOperationException("ConsistentField not currently supported") @@ -321,16 +326,16 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(filter: relation.Filter): LogicalPlan = { - val child = filter.getInput.accept(this) + override def visit(filter: relation.Filter, context: EmptyVisitationContext): LogicalPlan = { + val child = filter.getInput.accept(this, context) withChild(child) { - val condition = filter.getCondition.accept(expressionConverter) + val condition = filter.getCondition.accept(expressionConverter, context) Filter(condition, child) } } - override def visit(set: relation.Set): LogicalPlan = { - val children = set.getInputs.asScala.map(_.accept(this)) + override def visit(set: relation.Set, context: EmptyVisitationContext): LogicalPlan = { + val children = set.getInputs.asScala.map(_.accept(this, context)) withOutput(children.flatMap(_.output)) { set.getSetOp match { case SetOp.UNION_ALL => Union(children, byName = false, allowMissingCol = false) @@ -340,18 +345,22 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(emptyScan: relation.EmptyScan): LogicalPlan = { + override def visit( + emptyScan: relation.EmptyScan, + context: EmptyVisitationContext): LogicalPlan = { LocalRelation(ToSparkType.toAttributeSeq(emptyScan.getInitialSchema)) } - override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = { + override def visit( + virtualTableScan: relation.VirtualTableScan, + context: EmptyVisitationContext): LogicalPlan = { val rows = virtualTableScan.getRows.asScala.map( row => InternalRow.fromSeq( row .fields() .asScala - .map(field => field.accept(expressionConverter).asInstanceOf[Literal].value))) + .map(field => field.accept(expressionConverter, context).asInstanceOf[Literal].value))) virtualTableScan.getInitialSchema match { case ns: NamedStruct if ns.names().isEmpty && rows.length == 1 => OneRowRelation() @@ -360,14 +369,16 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } } - override def visit(namedScan: relation.NamedScan): LogicalPlan = { + override def visit( + namedScan: relation.NamedScan, + context: EmptyVisitationContext): LogicalPlan = { resolve(UnresolvedRelation(namedScan.getNames.asScala)) match { case m: MultiInstanceRelation => m.newInstance() case other => other } } - override def visit(localFiles: LocalFiles): LogicalPlan = { + override def visit(localFiles: LocalFiles, context: EmptyVisitationContext): LogicalPlan = { val schema = ToSparkType.toStructType(localFiles.getInitialSchema) val output = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) @@ -438,7 +449,7 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate()) } def convert(rel: relation.Rel): LogicalPlan = { - val logicalPlan = rel.accept(this) + val logicalPlan = rel.accept(this, null) require(logicalPlan.resolved) logicalPlan } diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index daa047f95..ab6d9efef 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -575,7 +575,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { proto.PlanRel .newBuilder() .setRel(substraitRel - .accept(relProtoConverter)) + .accept(relProtoConverter, null)) ) extensionCollector.addExtensionsToPlan(builder) builder.build().toByteArray diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index 063dd94f1..491c05fa1 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -72,7 +72,7 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => // convert substrait back to spark plan val toLogicalPlan = new ToLogicalPlan(spark); - val sparkPlan2 = substraitRel2.accept(toLogicalPlan) + val sparkPlan2 = substraitRel2.accept(toLogicalPlan, null) require(sparkPlan2.resolved) // and back to substrait again diff --git a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala index 7232a116f..e649487cc 100644 --- a/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala +++ b/spark/src/test/scala/io/substrait/spark/TypesAndLiteralsSuite.scala @@ -101,7 +101,7 @@ class TypesAndLiteralsSuite extends SparkFunSuite { l => { test(s"test literal: $l (${l.dataType})") { val substraitLiteral = ToSubstraitLiteral.convert(l).get - val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal] + val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal] println("Before: " + l + " " + l.dataType) println("After: " + sparkLiteral + " " + sparkLiteral.dataType) @@ -118,7 +118,7 @@ class TypesAndLiteralsSuite extends SparkFunSuite { MapType(IntegerType, StringType, valueContainsNull = false)) val substraitLiteral = ToSubstraitLiteral.convert(l).get - val sparkLiteral = substraitLiteral.accept(toSparkExpression).asInstanceOf[Literal] + val sparkLiteral = substraitLiteral.accept(toSparkExpression, null).asInstanceOf[Literal] println("Before: " + l + " " + l.dataType) println("After: " + sparkLiteral + " " + sparkLiteral.dataType) diff --git a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala index 45de335bc..fa0381b1b 100644 --- a/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/expression/SubstraitExpressionTestBase.scala @@ -48,7 +48,7 @@ trait SubstraitExpressionTestBase { f(substraitExp) if (bidirectional) { - val convertedExpression = substraitExp.accept(toSparkExpression) + val convertedExpression = substraitExp.accept(toSparkExpression, null) assertResult(expression)(convertedExpression) } }