Skip to content

Commit b6db474

Browse files
committed
enforce null literal is nullable
1 parent ee8c121 commit b6db474

5 files changed

Lines changed: 105 additions & 14 deletions

File tree

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,26 @@ <R, C extends VisitationContext, E extends Throwable> R accept(
5353
abstract class NullLiteral implements Literal {
5454
public abstract Type type();
5555

56+
/** A null literal is inherently nullable - you can't have a null of a non-nullable type. */
57+
@Override
58+
public boolean nullable() {
59+
return true;
60+
}
61+
62+
/** Returns this literal unchanged since null literals are always nullable. */
63+
@Override
64+
public NullLiteral withNullable(boolean nullable) {
65+
return this;
66+
}
67+
68+
@Value.Check
69+
protected void check() {
70+
if (!type().nullable()) {
71+
throw new IllegalArgumentException(
72+
"NullLiteral requires a nullable type, but got: " + type());
73+
}
74+
}
75+
5676
@Override
5777
public Type getType() {
5878
return type();

core/src/test/java/io/substrait/type/proto/LiteralRoundtripTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static io.substrait.expression.proto.ProtoExpressionConverter.EMPTY_TYPE;
44
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
import static org.junit.jupiter.api.Assertions.assertThrows;
56

67
import com.google.protobuf.Any;
78
import io.substrait.TestBase;
@@ -287,4 +288,13 @@ void userDefinedLiteralWithAllParameterTypes() {
287288

288289
verifyNestedTypesRoundTrip(multiParam);
289290
}
291+
292+
@Test
293+
void nullLiteralRejectsNonNullableType() {
294+
io.substrait.type.Type nonNullableType =
295+
io.substrait.type.Type.I32.builder().nullable(false).build();
296+
297+
assertThrows(
298+
IllegalArgumentException.class, () -> ExpressionCreator.typedNull(nonNullableType));
299+
}
290300
}

isthmus/src/main/java/io/substrait/isthmus/expression/CallConverters.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,11 @@ public class CallConverters {
146146
Expression.Literal lit = (Expression.Literal) operands.get(i);
147147
boolean fieldIsNullable = fieldTypes.get(i).getType().isNullable();
148148

149-
// ROW types are never nullable (struct literals are always concrete values).
150-
// Field nullability comes from individual field types.
151-
if (fieldIsNullable && !lit.nullable()) {
152-
lit = lit.withNullable(true);
153-
}
149+
// Calcite's RexBuilder.makeLiteral() always strips nullability from literal types,
150+
// because a concrete value (like 42) is never null. However, Substrait tracks
151+
// nullability at the field level to indicate what values that position can hold.
152+
// We restore the correct nullability from the ROW's field type schema.
153+
lit = lit.withNullable(fieldIsNullable);
154154
literals.add(lit);
155155
}
156156

isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,7 @@ void tStructRoundtripMixedFieldNullability() {
419419
void tStructRoundtripWithNullFieldValues() {
420420
// Test struct with actual NULL field values roundtrips correctly
421421
Expression.NullLiteral nullField =
422-
Expression.NullLiteral.builder()
423-
.nullable(true)
424-
.type(io.substrait.type.Type.I32.builder().nullable(true).build())
425-
.build();
422+
ExpressionCreator.typedNull(io.substrait.type.Type.I32.builder().nullable(true).build());
426423

427424
Expression.StructLiteral struct =
428425
ExpressionCreator.struct(false, nullField, ExpressionCreator.i32(false, 100));

isthmus/src/test/java/io/substrait/isthmus/UserDefinedLiteralRoundtripTest.java

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,15 @@ void parameterizedUdtAllParamKindsRoundTrip() {
270270
assertRoundTrip(literal);
271271
}
272272

273+
/**
274+
* Test field-level nullability: struct is non-nullable, but fields are nullable.
275+
*
276+
* <p>Note: The YAML schema declares point fields as non-nullable (i32), but we're using nullable
277+
* fields here. This works because we don't currently validate against the schema. See
278+
* https://github.com/substrait-io/substrait-java/issues/614
279+
*/
273280
@Test
274281
void nullableFieldsInStructUdtRoundTrip() {
275-
// Test field-level nullability: struct is non-nullable, but fields are nullable
276282
Expression.UserDefinedStructLiteral literal =
277283
ExpressionCreator.userDefinedLiteralStruct(
278284
false,
@@ -286,10 +292,16 @@ void nullableFieldsInStructUdtRoundTrip() {
286292
assertRoundTrip(literal);
287293
}
288294

295+
/**
296+
* Test mixed field nullability: struct is non-nullable, first field nullable, second
297+
* non-nullable.
298+
*
299+
* <p>Note: The YAML schema declares point fields as non-nullable (i32), but we're using a
300+
* nullable field here. This works because we don't currently validate against the schema. See
301+
* https://github.com/substrait-io/substrait-java/issues/614
302+
*/
289303
@Test
290304
void mixedFieldNullabilityInStructUdtRoundTrip() {
291-
// Test mixed field nullability: struct is non-nullable, first field nullable, second
292-
// non-nullable
293305
Expression.UserDefinedStructLiteral literal =
294306
ExpressionCreator.userDefinedLiteralStruct(
295307
false,
@@ -303,9 +315,9 @@ void mixedFieldNullabilityInStructUdtRoundTrip() {
303315
assertRoundTrip(literal);
304316
}
305317

318+
/** Test struct-level nullability: struct is nullable, fields are non-nullable. */
306319
@Test
307320
void nullableStructEncodedUdtRoundTrip() {
308-
// Test struct-level nullability: struct is nullable, fields are non-nullable
309321
Expression.UserDefinedStructLiteral literal =
310322
ExpressionCreator.userDefinedLiteralStruct(
311323
true,
@@ -319,9 +331,15 @@ void nullableStructEncodedUdtRoundTrip() {
319331
assertRoundTrip(literal);
320332
}
321333

334+
/**
335+
* Test the critical case: nullable struct with mixed field nullability.
336+
*
337+
* <p>Note: The YAML schema declares point fields as non-nullable (i32), but we're using a
338+
* nullable field here. This works because we don't currently validate against the schema. See
339+
* https://github.com/substrait-io/substrait-java/issues/614
340+
*/
322341
@Test
323342
void nullableStructWithMixedFieldNullabilityRoundTrip() {
324-
// Test the critical case: nullable struct with mixed field nullability
325343
Expression.UserDefinedStructLiteral literal =
326344
ExpressionCreator.userDefinedLiteralStruct(
327345
true,
@@ -396,4 +414,50 @@ void listAndMapFieldsInStructUdtRoundTrip() {
396414

397415
assertRoundTrip(literal);
398416
}
417+
418+
/**
419+
* Test with an actual null value (NullLiteral) in a struct field.
420+
*
421+
* <p>Note: The YAML schema declares point fields as non-nullable (i32), but we're using nullable
422+
* fields here. This works because we don't currently validate against the schema. When
423+
* https://github.com/substrait-io/substrait-java/issues/614 is implemented, this test will need
424+
* to be updated to use a UDT with nullable fields in its schema.
425+
*/
426+
@Test
427+
void nullValueInStructUdtFieldRoundTrip() {
428+
Expression.UserDefinedStructLiteral literal =
429+
ExpressionCreator.userDefinedLiteralStruct(
430+
false,
431+
NESTED_TYPES_URN,
432+
"point",
433+
Collections.emptyList(),
434+
Arrays.asList(
435+
ExpressionCreator.typedNull(Type.I32.builder().nullable(true).build()),
436+
ExpressionCreator.i32(true, 100)));
437+
438+
assertRoundTrip(literal);
439+
}
440+
441+
/**
442+
* Test with a mix of null and non-null values in struct fields.
443+
*
444+
* <p>Note: The YAML schema declares point fields as non-nullable (i32), but we're using nullable
445+
* fields here. This works because we don't currently validate against the schema. When
446+
* https://github.com/substrait-io/substrait-java/issues/614 is implemented, this test will need
447+
* to be updated to use a UDT with nullable fields in its schema.
448+
*/
449+
@Test
450+
void mixedNullAndNonNullValuesInStructUdtRoundTrip() {
451+
Expression.UserDefinedStructLiteral literal =
452+
ExpressionCreator.userDefinedLiteralStruct(
453+
false,
454+
NESTED_TYPES_URN,
455+
"point",
456+
Collections.emptyList(),
457+
Arrays.asList(
458+
ExpressionCreator.i32(true, 42),
459+
ExpressionCreator.typedNull(Type.I32.builder().nullable(true).build())));
460+
461+
assertRoundTrip(literal);
462+
}
399463
}

0 commit comments

Comments
 (0)