Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a Lambda expression.
*
* <p>This implementation does no Lambda-specific work; it hands the expression to {@code
* visitFallback}.
*
* @param expr the Lambda expression
* @param context the visitation context
* @return the result returned by {@code visitFallback}
* @throws E if visitation fails
*/
@Override
public O visit(Expression.Lambda expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a scalar function invocation.
*
Expand Down
27 changes: 27 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,33 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * A lambda expression: a struct describing the parameters together with a body expression.
 */
@Value.Immutable
abstract class Lambda implements Expression {
  /**
   * Returns the struct whose fields are the lambda's parameter types.
   *
   * @return the parameter struct
   */
  public abstract Type.Struct parameters();

  /**
   * Returns the expression that forms the lambda body.
   *
   * @return the body expression
   */
  public abstract Expression body();

  /**
   * Computes the type of this lambda as a non-nullable function type whose parameter types are
   * the fields of {@link #parameters()} and whose return type is the type of {@link #body()}.
   *
   * @return the function type of this lambda
   */
  @Override
  public Type getType() {
    final List<Type> argumentTypes = parameters().fields();
    final Type resultType = body().getType();

    // TODO: fix Lambda return type once this issue
    // https://github.com/substrait-io/substrait/issues/976 is resolved
    return Type.withNullability(false).func(argumentTypes, resultType);
  }

  /**
   * Dispatches to the visitor overload for Lambda expressions.
   *
   * @param visitor the visitor to call
   * @param context the visitation context passed through to the visitor
   * @return whatever the visitor returns
   * @throws E if the visitor throws
   */
  @Override
  public <R, C extends VisitationContext, E extends Throwable> R accept(
      ExpressionVisitor<R, C, E> visitor, C context) throws E {
    return visitor.visit(this, context);
  }

  /**
   * Creates a new builder for Lambda expressions.
   *
   * @return a fresh builder
   */
  public static ImmutableExpression.Lambda.Builder builder() {
    return ImmutableExpression.Lambda.builder();
  }
}

/**
* Base interface for user-defined literals.
*
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
*/
R visit(Expression.NestedStruct expr, C context) throws E;

/**
* Visit a Lambda expression.
*
* @param expr the Lambda expression to visit
* @param context the visitation context
* @return the visit result
* @throws E on visit failure
*/
R visit(Expression.Lambda expr, C context) throws E;

/**
* Visit a user-defined any literal.
*
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/java/io/substrait/expression/FieldReference.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {

public abstract Optional<Integer> outerReferenceStepsOut();

public abstract Optional<Integer> lambdaParameterReferenceStepsOut();

@Override
public Type getType() {
return type();
Expand All @@ -38,13 +40,18 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
public boolean isSimpleRootReference() {
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
&& !outerReferenceStepsOut().isPresent()
&& !lambdaParameterReferenceStepsOut().isPresent();
}

/**
* Reports whether this reference is an outer reference.
*
* @return {@code true} if {@code outerReferenceStepsOut()} holds a value greater than zero;
*     {@code false} if it is empty or holds a value of zero or less
*/
public boolean isOuterReference() {
return outerReferenceStepsOut().orElse(0) > 0;
}

/**
* Reports whether this reference is a lambda parameter reference.
*
* @return {@code true} if {@code lambdaParameterReferenceStepsOut()} is present, whatever its
*     value (a value of zero also counts)
*/
public boolean isLambdaParameterReference() {
return lambdaParameterReferenceStepsOut().isPresent();
}

public FieldReference dereferenceStruct(int index) {
Type newType = StructFieldFinder.getReferencedType(type(), index);
return dereference(newType, StructField.of(index));
Expand Down Expand Up @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
index, currentOffset));
}

/**
 * Creates a field reference that points at a lambda parameter.
 *
 * @param paramIndex index of the referenced field within {@code lambdaParamsType}
 * @param lambdaParamsType struct whose fields are the parameter types of the referenced lambda
 * @param stepsOut value recorded as the reference's lambda parameter steps-out
 * @return a reference with a single struct-field segment for {@code paramIndex}, typed as the
 *     field of {@code lambdaParamsType} at that index
 */
public static FieldReference newLambdaParameterReference(
    int paramIndex, Type.Struct lambdaParamsType, int stepsOut) {
  final ImmutableFieldReference.Builder reference = ImmutableFieldReference.builder();
  reference.addSegments(StructField.of(paramIndex));
  // The reference takes the type of the selected parameter, not of the whole parameter struct.
  reference.type(lambdaParamsType.fields().get(paramIndex));
  reference.lambdaParameterReferenceStepsOut(stepsOut);
  return reference.build();
}

public interface ReferenceSegment {
FieldReference apply(FieldReference reference);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,18 @@ public Expression visit(
});
}

/**
 * Converts a Lambda expression into its protobuf representation.
 *
 * <p>The parameter struct is converted through {@code typeProtoConverter}; the body is converted
 * by visiting it with this same converter.
 *
 * @param expr the Lambda expression to convert
 * @param context the visitation context, passed on when visiting the body
 * @return a protobuf expression with its lambda field set
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
    throws RuntimeException {
  final io.substrait.proto.Expression.Lambda.Builder lambda =
      io.substrait.proto.Expression.Lambda.newBuilder();
  lambda.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct());
  lambda.setBody(expr.body().accept(this, context));

  return io.substrait.proto.Expression.newBuilder().setLambda(lambda).build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
Expand Down Expand Up @@ -603,6 +615,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
out.setOuterReference(
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
.setStepsOut(expr.outerReferenceStepsOut().get()));
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
out.setLambdaParameterReference(
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
} else {
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack();

public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -75,6 +76,19 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getDirectReference().getStructField().getField(),
rootType,
reference.getOuterReference().getStepsOut());
case LAMBDA_PARAMETER_REFERENCE:
{
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
reference.getLambdaParameterReference();

int stepsOut = lambdaParamRef.getStepsOut();
Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut);

return FieldReference.newLambdaParameterReference(
reference.getDirectReference().getStructField().getField(),
Copy link

@Slimsammylim Slimsammylim Feb 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something I'm curious about is why we use reference.getDirectReference().getStructField().getField() here instead of getDirectReferenceSegments(reference.getDirectReference()) like the ROOT_REFERENCE case does.

I'm unfamiliar with substrait-java, but does this current implementation support nested field access through lambda parameters?

*This may be an unimportant issue especially if it's not common to access nested fields in lambdas

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the ofRoot method takes the struct itself, while newLambdaParameterReference takes the index, like the newRootStructOuterReference method.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I am still confused as to whether this supports nested field access through lambda parameters. I messed around with some tests through Claude, and I believe it is unsupported. However, this is a nit and I'm not sure it will come up often in practice, so we can come back to this later if it ends up being an issue!

lambdaParameters,
stepsOut);
}
case ROOTTYPE_NOT_SET:
default:
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
Expand Down Expand Up @@ -260,6 +274,27 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case LAMBDA:
{
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
Type.Struct parameters =
(Type.Struct)
protoTypeConverter.from(
io.substrait.proto.Type.newBuilder()
.setStruct(protoLambda.getParameters())
.build());

lambdaParameterStack.push(parameters);

Expression body;
try {
body = from(protoLambda.getBody());
} finally {
lambdaParameterStack.pop();
}

return Expression.Lambda.builder().parameters(parameters).body(body).build();
}
// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down Expand Up @@ -574,4 +609,42 @@ public Expression.SortField fromSortField(SortField s) {
/**
* Converts a protobuf function option into a {@link FunctionOption}.
*
* @param o the protobuf function option to convert
* @return a function option built from the proto's name and all entries of its preference list
*/
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
}

/**
 * Keeps the parameter structs of the lambdas currently being parsed, innermost last.
 *
 * <p>Each lambda pushes its parameters before its body is parsed and pops them afterwards. A
 * lambda parameter reference then selects an enclosing lambda through its "stepsOut" value:
 *
 * <ul>
 *   <li>stepsOut=0 selects the innermost (current) lambda
 *   <li>stepsOut=1 selects the lambda directly enclosing it
 *   <li>stepsOut=N selects the lambda N levels further out
 * </ul>
 */
private static class LambdaParameterStack {
  private final List<Type.Struct> stack = new ArrayList<>();

  /**
   * Records the parameters of a lambda that is about to be parsed.
   *
   * @param parameters the parameter struct of the lambda being entered
   */
  void push(Type.Struct parameters) {
    stack.add(parameters);
  }

  /**
   * Discards the parameters of the innermost lambda.
   *
   * @throws IllegalArgumentException if no lambda is currently recorded
   */
  void pop() {
    final int top = stack.size() - 1;
    if (top < 0) {
      throw new IllegalArgumentException("Lambda parameter stack is empty");
    }
    stack.remove(top);
  }

  /**
   * Looks up the parameters of the lambda that lies {@code stepsOut} levels outside the innermost
   * one.
   *
   * @param stepsOut number of lambda levels to step outwards; 0 is the innermost lambda
   * @return the parameter struct of the selected lambda
   * @throws IllegalArgumentException if {@code stepsOut} does not select a recorded lambda
   */
  Type.Struct get(int stepsOut) {
    final int depth = stack.size();
    final int position = depth - 1 - stepsOut;
    if (position >= 0 && position < depth) {
      return stack.get(position);
    }
    throw new IllegalArgumentException(
        String.format(
            "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
            stepsOut, depth));
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog {
/** Extension identifier for set functions. */
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";

/** Extension identifier for list functions. */
public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list";

/** Extension identifier for string functions. */
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator<T, I> {
T listE(T type);

T mapE(T key, T value);

/**
* Creates a function type from parameter types and a return type.
*
* @param parameterTypes the types of the function's parameters
* @param returnType the type of the function's result
* @return the resulting function type
*/
T funcE(Iterable<? extends T> parameterTypes, T returnType);
}
17 changes: 17 additions & 0 deletions core/src/main/java/io/substrait/function/ParameterizedType.java
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ <R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameter
}
}

/** A parameterized function type, made up of parameter types and a return type. */
@Value.Immutable
abstract class Func extends BaseParameterizedType implements NullableType {
/** Returns the types of the function's parameters. */
public abstract java.util.List<ParameterizedType> parameterTypes();

/** Returns the type of the function's result. */
public abstract ParameterizedType returnType();

/** Returns a new builder for {@code Func} instances. */
public static ImmutableParameterizedType.Func.Builder builder() {
return ImmutableParameterizedType.Func.builder();
}

/** Dispatches to the visitor's overload for {@code Func}. */
@Override
<R, E extends Throwable> R accept(final ParameterizedTypeVisitor<R, E> parameterizedTypeVisitor)
throws E {
return parameterizedTypeVisitor.visit(this);
}
}

@Value.Immutable
abstract class ListType extends BaseParameterizedType implements NullableType {
public abstract ParameterizedType name();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ public ParameterizedType listE(ParameterizedType type) {
return ParameterizedType.ListType.builder().nullable(nullable).name(type).build();
}

/**
 * Builds a parameterized function type that uses this creator's nullability.
 *
 * @param parameterTypes the types of the function's parameters
 * @param returnType the type of the function's result
 * @return the parameterized function type
 */
@Override
public ParameterizedType funcE(
    Iterable<? extends ParameterizedType> parameterTypes, ParameterizedType returnType) {
  final ImmutableParameterizedType.Func.Builder func = ParameterizedType.Func.builder();
  func.nullable(nullable);
  func.addAllParameterTypes(parameterTypes);
  func.returnType(returnType);
  return func.build();
}

@Override
public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) {
return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor<R, E extends Throwable> extends TypeVi

R visit(ParameterizedType.StringLiteral stringLiteral) throws E;

/** Visits a parameterized function type. */
R visit(ParameterizedType.Func expr) throws E;

abstract class ParameterizedTypeThrowsVisitor<R, E extends Throwable>
extends TypeVisitor.TypeThrowsVisitor<R, E> implements ParameterizedTypeVisitor<R, E> {

Expand Down Expand Up @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E {
public R visit(ParameterizedType.StringLiteral stringLiteral) throws E {
throw t();
}

/** Always throws the exception obtained from {@code t()}; it never returns a value. */
@Override
public R visit(ParameterizedType.Func expr) throws E {
throw t();
}
}
}
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/function/TypeExpression.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ <R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) th
}
}

/** A function type expression, made up of parameter types and a return type. */
@Value.Immutable
abstract class Func extends BaseTypeExpression implements NullableType {
/** Returns the type expressions of the function's parameters. */
public abstract java.util.List<TypeExpression> parameterTypes();

/** Returns the type expression of the function's result. */
public abstract TypeExpression returnType();

/** Returns a new builder for {@code Func} instances. */
public static ImmutableTypeExpression.Func.Builder builder() {
return ImmutableTypeExpression.Func.builder();
}

/** Dispatches to the visitor's overload for {@code Func}. */
@Override
<R, E extends Throwable> R acceptE(final TypeExpressionVisitor<R, E> visitor) throws E {
return visitor.visit(this);
}
}

@Value.Immutable
abstract class BinaryOperation extends BaseTypeExpression {
public enum OpType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) {
return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build();
}

/**
 * Builds a function type expression that uses this creator's nullability.
 *
 * @param parameterTypes the type expressions of the function's parameters
 * @param returnType the type expression of the function's result
 * @return the function type expression
 */
@Override
public TypeExpression funcE(
    Iterable<? extends TypeExpression> parameterTypes, TypeExpression returnType) {
  final ImmutableTypeExpression.Func.Builder func = TypeExpression.Func.builder();
  func.nullable(nullable);
  func.addAllParameterTypes(parameterTypes);
  func.returnType(returnType);
  return func.build();
}

public static class Assign {
String name;
TypeExpression expr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public interface TypeExpressionVisitor<R, E extends Throwable>

R visit(TypeExpression.Map expr) throws E;

/** Visits a function type expression. */
R visit(TypeExpression.Func expr) throws E;

R visit(TypeExpression.BinaryOperation expr) throws E;

R visit(TypeExpression.NotOperation expr) throws E;
Expand Down Expand Up @@ -97,6 +99,11 @@ public R visit(TypeExpression.Map expr) throws E {
throw t();
}

/** Always throws the exception obtained from {@code t()}; it never returns a value. */
@Override
public R visit(TypeExpression.Func expr) throws E {
throw t();
}

@Override
public R visit(TypeExpression.BinaryOperation expr) throws E {
throw t();
Expand Down
Loading