Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions isthmus/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ dependencies {
testImplementation(platform(libs.junit.bom))
testImplementation(libs.junit.jupiter)
testRuntimeOnly(libs.junit.platform.launcher)
testRuntimeOnly(libs.slf4j.jdk14)
implementation(libs.guava)
implementation(libs.protobuf.java.util) {
exclude("com.google.guava", "guava")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package io.substrait.isthmus;

import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.expression.AggregateFunctionConverter;
import io.substrait.isthmus.expression.FunctionMappings;
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorTable;
import org.apache.calcite.sql.util.SqlOperatorTables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AutomaticDynamicFunctionMappingConverterProvider extends ConverterProvider {

private static final Logger LOGGER =
LoggerFactory.getLogger(AutomaticDynamicFunctionMappingConverterProvider.class);

public AutomaticDynamicFunctionMappingConverterProvider() {
this(DefaultExtensionCatalog.DEFAULT_COLLECTION, SubstraitTypeSystem.TYPE_FACTORY);
}

public AutomaticDynamicFunctionMappingConverterProvider(SimpleExtension.ExtensionCollection extensions) {
this(extensions, SubstraitTypeSystem.TYPE_FACTORY);
}

public AutomaticDynamicFunctionMappingConverterProvider(
SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) {
super(extensions, typeFactory);
this.scalarFunctionConverter = createScalarFunctionConverter();
this.aggregateFunctionConverter = createAggregateFunctionConverter();
this.windowFunctionConverter = createWindowFunctionConverter();
}

@Override
public SqlOperatorTable getSqlOperatorTable() {
SqlOperatorTable baseOperatorTable = super.getSqlOperatorTable();
List<SqlOperator> dynamicOperators = new ArrayList<>();

List<SimpleExtension.ScalarFunctionVariant> unmappedScalars =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.scalarFunctions(),
io.substrait.isthmus.expression.FunctionMappings.SCALAR_SIGS);
List<SimpleExtension.AggregateFunctionVariant> unmappedAggregates =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.aggregateFunctions(),
io.substrait.isthmus.expression.FunctionMappings.AGGREGATE_SIGS);
List<SimpleExtension.WindowFunctionVariant> unmappedWindows =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.windowFunctions(),
io.substrait.isthmus.expression.FunctionMappings.WINDOW_SIGS);

if (!unmappedScalars.isEmpty()) {
dynamicOperators.addAll(SimpleExtensionToSqlOperator.from(unmappedScalars, typeFactory));
}
if (!unmappedAggregates.isEmpty()) {
dynamicOperators.addAll(SimpleExtensionToSqlOperator.from(unmappedAggregates, typeFactory));
}
if (!unmappedWindows.isEmpty()) {
dynamicOperators.addAll(SimpleExtensionToSqlOperator.from(unmappedWindows, typeFactory));
}

if (!dynamicOperators.isEmpty()) {
return SqlOperatorTables.chain(baseOperatorTable, SqlOperatorTables.of(dynamicOperators));
} else {
return baseOperatorTable;
}
}

protected ScalarFunctionConverter createScalarFunctionConverter() {
List<SimpleExtension.ScalarFunctionVariant> unmappedFunctions =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.scalarFunctions(), FunctionMappings.SCALAR_SIGS);

List<FunctionMappings.Sig> additionalSignatures = new ArrayList<>();

if (!unmappedFunctions.isEmpty()) {
LOGGER.info(
"Dynamically mapping {} unmapped scalar functions: {}",
unmappedFunctions.size(),
unmappedFunctions.stream().map(f -> f.name()).collect(Collectors.toList()));

List<SqlOperator> dynamicOperators =
SimpleExtensionToSqlOperator.from(unmappedFunctions, typeFactory);

java.util.Map<String, SqlOperator> operatorsByName = new java.util.LinkedHashMap<>();
for (SqlOperator op : dynamicOperators) {
operatorsByName.put(op.getName().toLowerCase(), op);
}

additionalSignatures.addAll(
operatorsByName.values().stream()
.map(op -> FunctionMappings.s(op, op.getName().toLowerCase()))
.collect(Collectors.toList()));
}

return new ScalarFunctionConverter(
extensions.scalarFunctions(), additionalSignatures, typeFactory, typeConverter);
}

protected AggregateFunctionConverter createAggregateFunctionConverter() {
List<FunctionMappings.Sig> additionalSignatures = new ArrayList<>();

List<SimpleExtension.AggregateFunctionVariant> unmappedFunctions =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.aggregateFunctions(), FunctionMappings.AGGREGATE_SIGS);

if (!unmappedFunctions.isEmpty()) {
List<SqlOperator> dynamicOperators =
SimpleExtensionToSqlOperator.from(unmappedFunctions, typeFactory);

java.util.Map<String, SqlOperator> operatorsByName = new java.util.LinkedHashMap<>();
for (SqlOperator op : dynamicOperators) {
operatorsByName.put(op.getName().toLowerCase(), op);
}

additionalSignatures.addAll(
operatorsByName.values().stream()
.map(op -> FunctionMappings.s(op, op.getName().toLowerCase()))
.collect(Collectors.toList()));
}

return new AggregateFunctionConverter(
extensions.aggregateFunctions(), additionalSignatures, typeFactory, typeConverter);
}

protected WindowFunctionConverter createWindowFunctionConverter() {
List<FunctionMappings.Sig> additionalSignatures = new ArrayList<>();

List<SimpleExtension.WindowFunctionVariant> unmappedFunctions =
io.substrait.isthmus.expression.FunctionConverter.getUnmappedFunctions(
extensions.windowFunctions(), FunctionMappings.WINDOW_SIGS);

if (!unmappedFunctions.isEmpty()) {
List<SqlOperator> dynamicOperators =
SimpleExtensionToSqlOperator.from(unmappedFunctions, typeFactory);

java.util.Map<String, SqlOperator> operatorsByName = new java.util.LinkedHashMap<>();
for (SqlOperator op : dynamicOperators) {
operatorsByName.put(op.getName().toLowerCase(), op);
}

additionalSignatures.addAll(
operatorsByName.values().stream()
.map(op -> FunctionMappings.s(op, op.getName().toLowerCase()))
.collect(Collectors.toList()));
}

return new WindowFunctionConverter(
extensions.windowFunctions(), additionalSignatures, typeFactory, typeConverter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,45 @@ public static List<SqlOperator> from(
SimpleExtension.ExtensionCollection collection,
RelDataTypeFactory typeFactory,
TypeConverter typeConverter) {
// TODO: add support for windows functions
return Stream.concat(
collection.scalarFunctions().stream(), collection.aggregateFunctions().stream())
Stream.concat(
collection.scalarFunctions().stream(), collection.aggregateFunctions().stream()),
collection.windowFunctions().stream())
.map(function -> toSqlFunction(function, typeFactory, typeConverter))
.collect(Collectors.toList());
}

/**
* Converts a list of functions to SqlOperators. Handles scalar, aggregate, and window functions.
*
* @param functions list of functions to convert
* @param typeFactory the Calcite type factory
* @return list of SqlOperators
*/
public static List<SqlOperator> from(
List<? extends SimpleExtension.Function> functions, RelDataTypeFactory typeFactory) {
return from(functions, typeFactory, TypeConverter.DEFAULT);
}

/**
* Converts a list of functions to SqlOperators. Handles scalar, aggregate, and window functions.
*
* <p>Each function variant is converted to a separate SqlOperator. Functions with the same base
* name but different type signatures (e.g., strftime:ts_str, strftime:ts_string) are ALL added to
* the operator table. Calcite will try to match the function call arguments against all available
* operators and select the one that matches. This allows functions with multiple signatures to be
* used correctly without explicit deduplication.
*
* @param functions list of functions to convert
* @param typeFactory the Calcite type factory
* @param typeConverter the type converter
* @return list of SqlOperators
*/
public static List<SqlOperator> from(
List<? extends SimpleExtension.Function> functions,
RelDataTypeFactory typeFactory,
TypeConverter typeConverter) {
return functions.stream()
.map(function -> toSqlFunction(function, typeFactory, typeConverter))
.collect(Collectors.toList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,30 @@ private Optional<F> signatureMatch(List<Type> inputTypes, Type outputType) {
for (F function : functions) {
List<SimpleExtension.Argument> args = function.requiredArguments();
// Make sure that arguments & return are within bounds and match the types
if (function.returnType() instanceof ParameterizedType
&& isMatch(outputType, (ParameterizedType) function.returnType())
&& inputTypesMatchDefinedArguments(inputTypes, args)) {
boolean returnTypeMatches;
Object funcReturnType = function.returnType();

if (funcReturnType instanceof ParameterizedType) {
returnTypeMatches = isMatch(outputType, (ParameterizedType) funcReturnType);
} else if (funcReturnType instanceof Type) {
// For non-parameterized return types, check if they match
Type targetType = (Type) funcReturnType;
if (outputType instanceof ParameterizedType) {
// outputType is parameterized but targetType is not - use visitor pattern
returnTypeMatches =
((ParameterizedType) outputType)
.accept(new IgnoreNullableAndParameters(targetType));
} else {
// Both are non-parameterized types - compare them directly by using the visitor
// Create a simple visitor that just checks class equality
returnTypeMatches = outputType.getClass().equals(targetType.getClass());
}
} else {
// If function.returnType() is neither Type nor ParameterizedType, skip it
returnTypeMatches = false;
}

if (returnTypeMatches && inputTypesMatchDefinedArguments(inputTypes, args)) {
return Optional.of(function);
}
}
Expand Down Expand Up @@ -476,6 +497,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
if (leastRestrictive.isPresent()) {
return leastRestrictive;
}
} else {
// Fallback: try matchCoerced even if singularInputType is empty
// This handles functions with mixed argument types like strftime(timestamp, string)
Optional<T> coerced = matchCoerced(call, outputType, operands);
if (coerced.isPresent()) {
return coerced;
}
}
return Optional.empty();
}
Expand Down Expand Up @@ -565,4 +593,25 @@ private static boolean isMatch(ParameterizedType actualType, ParameterizedType t
}
return actualType.accept(new IgnoreNullableAndParameters(targetType));
}

/**
* Identifies functions that are not mapped in the provided Sig list.
*
* @param functions the list of function variants to check
* @param sigs the list of mapped Sig signatures
* @return a list of functions that are not found in the Sig mappings (case-insensitive name
* comparison)
*/
public static <F extends SimpleExtension.Function> List<F> getUnmappedFunctions(
List<F> functions, ImmutableList<FunctionMappings.Sig> sigs) {
Set<String> mappedNames =
sigs.stream()
.map(FunctionMappings.Sig::name)
.map(name -> name.toLowerCase(Locale.ROOT))
.collect(Collectors.toSet());

return functions.stream()
.filter(fn -> !mappedNames.contains(fn.name().toLowerCase(Locale.ROOT)))
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ public Boolean visit(Type.FP64 type) {

@Override
public Boolean visit(Type.Str type) {
return typeToMatch instanceof Type.Str;
// Treat all string types as compatible: Str, VarChar, and FixedChar
return typeToMatch instanceof Type.Str
|| typeToMatch instanceof Type.VarChar
|| typeToMatch instanceof Type.FixedChar
|| typeToMatch instanceof ParameterizedType.VarChar
|| typeToMatch instanceof ParameterizedType.FixedChar;
}

@Override
Expand Down Expand Up @@ -108,13 +113,22 @@ public Boolean visit(Type.UserDefined type) throws RuntimeException {

@Override
public Boolean visit(Type.FixedChar type) {
// Treat all string types as compatible: Str, VarChar, and FixedChar
return typeToMatch instanceof Type.FixedChar
|| typeToMatch instanceof ParameterizedType.FixedChar;
|| typeToMatch instanceof ParameterizedType.FixedChar
|| typeToMatch instanceof Type.Str
|| typeToMatch instanceof Type.VarChar
|| typeToMatch instanceof ParameterizedType.VarChar;
}

@Override
public Boolean visit(Type.VarChar type) {
return typeToMatch instanceof Type.VarChar || typeToMatch instanceof ParameterizedType.VarChar;
// Treat all string types as compatible: Str, VarChar, and FixedChar
return typeToMatch instanceof Type.VarChar
|| typeToMatch instanceof ParameterizedType.VarChar
|| typeToMatch instanceof Type.Str
|| typeToMatch instanceof Type.FixedChar
|| typeToMatch instanceof ParameterizedType.FixedChar;
}

@Override
Expand All @@ -131,18 +145,21 @@ public Boolean visit(Type.Decimal type) {
@Override
public Boolean visit(Type.PrecisionTime type) {
return typeToMatch instanceof Type.PrecisionTime
|| typeToMatch instanceof Type.Time
|| typeToMatch instanceof ParameterizedType.PrecisionTime;
}

@Override
public Boolean visit(Type.PrecisionTimestamp type) {
return typeToMatch instanceof Type.PrecisionTimestamp
|| typeToMatch instanceof Type.Timestamp
|| typeToMatch instanceof ParameterizedType.PrecisionTimestamp;
}

@Override
public Boolean visit(Type.PrecisionTimestampTZ type) {
return typeToMatch instanceof Type.PrecisionTimestampTZ
|| typeToMatch instanceof Type.TimestampTZ
|| typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ;
}

Expand All @@ -164,13 +181,22 @@ public Boolean visit(Type.Map type) {

@Override
public Boolean visit(ParameterizedType.FixedChar expr) throws RuntimeException {
// Treat all string types as compatible: Str, VarChar, and FixedChar
return typeToMatch instanceof Type.FixedChar
|| typeToMatch instanceof ParameterizedType.FixedChar;
|| typeToMatch instanceof ParameterizedType.FixedChar
|| typeToMatch instanceof Type.Str
|| typeToMatch instanceof Type.VarChar
|| typeToMatch instanceof ParameterizedType.VarChar;
}

@Override
public Boolean visit(ParameterizedType.VarChar expr) throws RuntimeException {
return typeToMatch instanceof Type.VarChar || typeToMatch instanceof ParameterizedType.VarChar;
// Treat all string types as compatible: Str, VarChar, and FixedChar
return typeToMatch instanceof Type.VarChar
|| typeToMatch instanceof ParameterizedType.VarChar
|| typeToMatch instanceof Type.Str
|| typeToMatch instanceof Type.FixedChar
|| typeToMatch instanceof ParameterizedType.FixedChar;
}

@Override
Expand Down Expand Up @@ -199,18 +225,21 @@ public Boolean visit(ParameterizedType.IntervalCompound expr) throws RuntimeExce
@Override
public Boolean visit(ParameterizedType.PrecisionTime expr) throws RuntimeException {
return typeToMatch instanceof Type.PrecisionTime
|| typeToMatch instanceof Type.Time
|| typeToMatch instanceof ParameterizedType.PrecisionTime;
}

@Override
public Boolean visit(ParameterizedType.PrecisionTimestamp expr) throws RuntimeException {
return typeToMatch instanceof Type.PrecisionTimestamp
|| typeToMatch instanceof Type.Timestamp
|| typeToMatch instanceof ParameterizedType.PrecisionTimestamp;
}

@Override
public Boolean visit(ParameterizedType.PrecisionTimestampTZ expr) throws RuntimeException {
return typeToMatch instanceof Type.PrecisionTimestampTZ
|| typeToMatch instanceof Type.TimestampTZ
|| typeToMatch instanceof ParameterizedType.PrecisionTimestampTZ;
}

Expand Down
Loading
Loading