Skip to content

Add expressions support for Virtual Tables #417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ public OUTPUT visit(Expression.StructLiteral expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.StructNested expr) throws EXCEPTION {
return visitFallback(expr);
}

@Override
public OUTPUT visit(Expression.Switch expr) throws EXCEPTION {
return visitFallback(expr);
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,27 @@ public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws
}
}

@Value.Immutable
abstract static class StructNested implements Expression {
public abstract List<Expression> fields();

public Type getType() {
return Type.withNullability(false)
.struct(
fields().stream()
.map(Expression::getType)
.collect(java.util.stream.Collectors.toList()));
}

public static ImmutableExpression.StructNested.Builder builder() {
return ImmutableExpression.StructNested.builder();
}

public <R, E extends Throwable> R accept(ExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract static class UserDefinedLiteral implements Literal {
public abstract ByteString value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ public interface ExpressionVisitor<R, E extends Throwable> {

R visit(Expression.StructLiteral expr) throws E;

R visit(Expression.StructNested expr) throws E;

R visit(Expression.UserDefinedLiteral expr) throws E;

R visit(Expression.Switch expr) throws E;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ private Expression lit(Consumer<Expression.Literal.Builder> consumer) {
return Expression.newBuilder().setLiteral(builder).build();
}

private Expression nested(Consumer<Expression.Nested.Builder> consumer) {
var builder = Expression.Nested.newBuilder();
consumer.accept(builder);
return Expression.newBuilder().setNested(builder).build();
}

@Override
public Expression visit(io.substrait.expression.Expression.BoolLiteral expr) {
return lit(bldr -> bldr.setNullable(expr.nullable()).setBoolean(expr.value()));
Expand Down Expand Up @@ -323,6 +329,18 @@ public Expression visit(io.substrait.expression.Expression.StructLiteral expr) {
});
}

@Override
public Expression visit(io.substrait.expression.Expression.StructNested expr) {
return nested(
bldr -> {
var values =
expr.fields().stream()
.map(this::toProto)
.collect(java.util.stream.Collectors.toList());
bldr.setStruct(Expression.Nested.Struct.newBuilder().addAllFields(values));
});
}

@Override
public Expression visit(io.substrait.expression.Expression.UserDefinedLiteral expr) {
var typeReference =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ public Optional<Expression> visit(Expression.StructLiteral expr) throws EXCEPTIO
return visitLiteral(expr);
}

@Override
public Optional<Expression> visit(Expression.StructNested expr) throws EXCEPTION {
var expressions = visitExprList(expr.fields());
return expressions.map(
expressionList ->
Expression.StructNested.builder().from(expr).fields(expressionList).build());
}

@Override
public Optional<Expression> visit(Expression.UserDefinedLiteral expr) throws EXCEPTION {
return visitLiteral(expr);
Expand Down
30 changes: 23 additions & 7 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ public Rel from(io.substrait.proto.Rel rel) {
protected Rel newRead(ReadRel rel) {
if (rel.hasVirtualTable()) {
var virtualTable = rel.getVirtualTable();
if (virtualTable.getValuesCount() == 0) {
if (virtualTable.getValuesCount() == 0 && virtualTable.getExpressionsCount() == 0) {
return newEmptyScan(rel);
} else {
return newVirtualTable(rel);
Expand Down Expand Up @@ -417,17 +417,33 @@ protected FileOrFiles newFileOrFiles(ReadRel.LocalFiles.FileOrFiles file) {

protected VirtualTableScan newVirtualTable(ReadRel rel) {
var virtualTable = rel.getVirtualTable();
// If both values and expressions are set, raise an error
if (virtualTable.getValuesCount() > 0 && virtualTable.getExpressionsCount() > 0) {
throw new IllegalArgumentException(
"Virtual table cannot have both values and expressions set");
}

var virtualTableSchema = newNamedStruct(rel);

var converter =
new ProtoExpressionConverter(lookup, extensions, virtualTableSchema.struct(), this);
List<Expression.StructLiteral> structLiterals = new ArrayList<>(virtualTable.getValuesCount());

List<Expression> expressions =
new ArrayList<>(virtualTable.getValuesCount() + virtualTable.getExpressionsCount());

for (var struct : virtualTable.getValuesList()) {
structLiterals.add(
expressions.add(
ImmutableExpression.StructLiteral.builder()
.fields(
struct.getFieldsList().stream()
.map(converter::from)
.collect(java.util.stream.Collectors.toList()))
struct.getFieldsList().stream().map(converter::from).collect(Collectors.toList()))
.build());
}

for (var expr : virtualTable.getExpressionsList()) {
expressions.add(
ImmutableExpression.StructNested.builder()
.fields(
expr.getFieldsList().stream().map(converter::from).collect(Collectors.toList()))
.build());
}

Expand All @@ -438,7 +454,7 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) {
rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null))
.filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null))
.initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter))
.rows(structLiterals);
.rows(expressions);

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import java.util.List;
import java.util.Objects;
import org.immutables.value.Value;

@Value.Immutable
public abstract class VirtualTableScan extends AbstractReadRel {

public abstract List<Expression.StructLiteral> getRows();
public abstract List<Expression> getRows();

/**
*
Expand All @@ -29,9 +30,9 @@ protected void check() {
== NamedFieldCountingTypeVisitor.countNames(this.getInitialSchema().struct());
var rows = getRows();

assert rows.size() > 0
&& names.stream().noneMatch(s -> s == null)
&& rows.stream().noneMatch(r -> r == null)
assert !rows.isEmpty()
&& names.stream().noneMatch(Objects::isNull)
&& rows.stream().noneMatch(Objects::isNull)
&& rows.stream()
.allMatch(r -> NamedFieldCountingTypeVisitor.countNames(r.getType()) == names.size());
}
Expand Down
28 changes: 28 additions & 0 deletions core/src/test/java/io/substrait/relation/VirtualTableScanTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package io.substrait.relation;

import static io.substrait.expression.ExpressionCreator.bool;
import static io.substrait.expression.ExpressionCreator.fp32;
import static io.substrait.expression.ExpressionCreator.fp64;
import static io.substrait.expression.ExpressionCreator.i8;
import static io.substrait.expression.ExpressionCreator.i16;
import static io.substrait.expression.ExpressionCreator.i32;
import static io.substrait.expression.ExpressionCreator.i64;
import static io.substrait.expression.ExpressionCreator.list;
import static io.substrait.expression.ExpressionCreator.map;
import static io.substrait.expression.ExpressionCreator.string;
Expand All @@ -25,6 +32,13 @@ void check() {
NamedStruct.of(
Arrays.stream(
new String[] {
"bool_field",
"i8_field",
"i16_field",
"i32_field",
"i64_field",
"fp32_field",
"fp64_field",
"string",
"struct",
"struct_field1",
Expand All @@ -37,13 +51,27 @@ void check() {
})
.collect(Collectors.toList()),
R.struct(
R.BOOLEAN,
R.I8,
R.I16,
R.I32,
R.I64,
R.FP32,
R.FP64,
R.STRING,
R.struct(R.STRING, R.STRING),
R.list(R.struct(R.STRING)),
R.map(R.struct(R.STRING), R.struct(R.STRING)))))
.addRows(
struct(
false,
bool(false, true),
i8(false, 42),
i16(false, 1234),
i32(false, 123456),
i64(false, 9876543210L),
fp32(false, 3.14f),
fp64(false, 2.718281828),
string(false, "string_val"),
struct(
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,21 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
}

override def visit(virtualTableScan: relation.VirtualTableScan): LogicalPlan = {
val rows = virtualTableScan.getRows.asScala.map(
row =>
val rows = virtualTableScan.getRows.asScala.map {
case structLit: SExpression.StructLiteral =>
InternalRow.fromSeq(
row
.fields()
.asScala
.map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)))
structLit.fields.asScala
.map(field => field.accept(expressionConverter).asInstanceOf[Literal].value)
)
case structNested: SExpression.StructNested =>
InternalRow.fromSeq(
structNested.fields.asScala
.map(expr => expr.accept(expressionConverter))
)
case other =>
throw new UnsupportedOperationException(
s"Unsupported row type in VirtualTableScan: ${other.getClass}")
}
virtualTableScan.getInitialSchema match {
case ns: NamedStruct if ns.names().isEmpty && rows.length == 1 =>
OneRowRelation()
Expand Down