diff --git a/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java b/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
index edf9ae50e18..54a429e4cfb 100644
--- a/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
+++ b/api/src/main/java/org/opensearch/sql/api/UnifiedQueryPlanner.java
@@ -60,7 +60,15 @@ public UnifiedQueryPlanner(UnifiedQueryContext context) {
    */
   public RelNode plan(String query) {
     try {
-      return context.measure(ANALYZE, () -> strategy.plan(query));
+      return context.measure(
+          ANALYZE,
+          () -> {
+            RelNode plan = strategy.plan(query);
+            for (var shuttle : context.getLangSpec().postAnalysisRules()) {
+              plan = plan.accept(shuttle);
+            }
+            return plan;
+          });
     } catch (SyntaxCheckException | UnsupportedOperationException e) {
       throw e;
     } catch (Exception e) {
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java b/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
index 89167dc27a5..e824c89f8de 100644
--- a/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
+++ b/api/src/main/java/org/opensearch/sql/api/spec/LanguageSpec.java
@@ -7,6 +7,7 @@
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.calcite.rel.RelShuttle;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.SqlOperatorTable;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
@@ -17,8 +18,8 @@
 
 /**
  * Language specification defining the dialect the engine accepts. Provides parser configuration,
- * validator configuration, and composable {@link LanguageExtension}s that contribute operators and
- * post-parse rewrite rules.
+ * validator configuration, and composable {@link LanguageExtension}s that contribute operators,
+ * post-parse rewrite rules, and post-analysis rewrite rules.
  *
  * <p>Implementations define a complete language surface — for example, {@link UnifiedSqlSpec}
  * provides ANSI and extended SQL modes. A future PPL spec would implement this same interface once
@@ -27,8 +28,9 @@
 public interface LanguageSpec {
 
   /**
-   * A composable language extension that contributes operators and post-parse rewrite rules. All
-   * methods have defaults so extensions only override what they need.
+   * A composable language extension that contributes operators, post-parse rewrite rules, and
+   * post-analysis rewrite rules. All methods have defaults so extensions only override what they
+   * need.
    */
   interface LanguageExtension {
 
@@ -47,6 +49,14 @@ default SqlOperatorTable operators() {
     default List> postParseRules() {
       return List.of();
     }
+
+    /**
+     * RelNode rewrite rules applied after analysis and before execution. Each rule transforms the
+     * logical plan tree. Rules within a single extension are applied in list order.
+     */
+    default List<RelShuttle> postAnalysisRules() {
+      return List.of();
+    }
   }
 
   /**
@@ -62,9 +72,9 @@ default List> postParseRules() {
   SqlValidator.Config validatorConfig();
 
   /**
-   * Language extensions registered with this spec. Each extension contributes operators and
-   * post-parse rewrite rules that are composed by {@link #operatorTable()} and {@link
-   * #postParseRules()}.
+   * Language extensions registered with this spec. Each extension contributes operators, post-parse
+   * rewrite rules, and post-analysis rewrite rules composed by {@link #operatorTable()}, {@link
+   * #postParseRules()}, and {@link #postAnalysisRules()}.
    */
   List<LanguageExtension> extensions();
 
@@ -86,4 +96,12 @@ default SqlOperatorTable operatorTable() {
   default List> postParseRules() {
     return extensions().stream().flatMap(ext -> ext.postParseRules().stream()).toList();
   }
+
+  /**
+   * All post-analysis RelNode rewrite rules from registered extensions, flattened in registration
+   * order. Applied to the logical plan after analysis and before execution.
+   */
+  default List<RelShuttle> postAnalysisRules() {
+    return extensions().stream().flatMap(ext -> ext.postAnalysisRules().stream()).toList();
+  }
 }
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java b/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
index 763f6ded540..a34b9b98806 100644
--- a/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
+++ b/api/src/main/java/org/opensearch/sql/api/spec/UnifiedPplSpec.java
@@ -10,6 +10,7 @@
 import lombok.NoArgsConstructor;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.validate.SqlValidator;
+import org.opensearch.sql.api.spec.datetime.DatetimeExtension;
 
 /**
  * PPL language specification.
@@ -37,6 +38,6 @@ public SqlValidator.Config validatorConfig() {
 
   @Override
   public List<LanguageExtension> extensions() {
-    return List.of();
+    return List.of(new DatetimeExtension());
   }
 }
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java
new file mode 100644
index 00000000000..944ac4a4bf1
--- /dev/null
+++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeExtension.java
@@ -0,0 +1,54 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import lombok.Getter;
+import lombok.RequiredArgsConstructor;
+import org.apache.calcite.rel.RelShuttle;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.opensearch.sql.api.spec.LanguageSpec.LanguageExtension;
+import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
+import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;
+
+/** Datetime language extension that normalizes UDT types and casts output for wire-format. */
+public class DatetimeExtension implements LanguageExtension {
+
+  /** Returns the UDT normalization rule followed by the output cast rule, in application order. */
+  @Override
+  public List<RelShuttle> postAnalysisRules() {
+    return List.of(DatetimeUdtNormalizeRule.INSTANCE, DatetimeOutputCastRule.INSTANCE);
+  }
+
+  /** Maps datetime UDT types to their standard Calcite equivalents. */
+  @Getter
+  @RequiredArgsConstructor
+  enum UdtMapping {
+    DATE(ExprUDT.EXPR_DATE, SqlTypeName.DATE),
+    TIME(ExprUDT.EXPR_TIME, SqlTypeName.TIME),
+    TIMESTAMP(ExprUDT.EXPR_TIMESTAMP, SqlTypeName.TIMESTAMP);
+
+    private final ExprUDT udtType;
+    private final SqlTypeName stdType;
+
+    /** Matches a UDT RelDataType to its mapping, or empty if not a datetime UDT. */
+    static Optional<UdtMapping> fromUdtType(RelDataType type) {
+      if (!(type instanceof AbstractExprRelDataType e)) {
+        return Optional.empty();
+      }
+      ExprUDT udt = e.getUdt();
+      return Arrays.stream(values()).filter(u -> u.udtType == udt).findFirst();
+    }
+
+    /** Returns true if the given SqlTypeName is a standard datetime type. */
+    static boolean isDatetimeType(SqlTypeName typeName) {
+      return Arrays.stream(values()).anyMatch(u -> u.stdType == typeName);
+    }
+  }
+}
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeOutputCastRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeOutputCastRule.java
new file mode 100644
index 00000000000..9a7ae25e003
--- /dev/null
+++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeOutputCastRule.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import static org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping.isDatetimeType;
+
+import java.util.ArrayList;
+import java.util.List;
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.type.SqlTypeName;
+
+/** Wraps the root output with CAST(datetime → VARCHAR) for PPL wire-format compatibility. */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+class DatetimeOutputCastRule extends RelHomogeneousShuttle {
+
+  static final DatetimeOutputCastRule INSTANCE = new DatetimeOutputCastRule();
+
+  /**
+   * Wraps {@code other} in a {@link LogicalProject} that casts every DATE/TIME/TIMESTAMP column to
+   * VARCHAR and passes all other columns through; returns {@code other} unchanged when it has no
+   * such column. Inputs are not visited, so only the node this shuttle is applied to is wrapped.
+   */
+  @Override
+  public RelNode visit(RelNode other) {
+    List<RelDataTypeField> fields = other.getRowType().getFieldList();
+    if (fields.stream().noneMatch(f -> isDatetimeType(f.getType().getSqlTypeName()))) {
+      return other;
+    }
+
+    RexBuilder rexBuilder = other.getCluster().getRexBuilder();
+    List<RexNode> projects = new ArrayList<>(fields.size());
+    List<String> names = new ArrayList<>(fields.size());
+
+    // Cast datetime fields to VARCHAR for output; pass through others unchanged
+    for (RelDataTypeField field : fields) {
+      RexNode newField = rexBuilder.makeInputRef(other, field.getIndex());
+      RelDataType fieldType = field.getType();
+      if (isDatetimeType(fieldType.getSqlTypeName())) {
+        projects.add(castToVarchar(rexBuilder, newField, fieldType));
+      } else {
+        projects.add(newField);
+      }
+      names.add(field.getName());
+    }
+    return LogicalProject.create(other, List.of(), projects, names);
+  }
+
+  /** Builds a cast of {@code expr} to VARCHAR that keeps the nullability of {@code fieldType}. */
+  private static RexNode castToVarchar(RexBuilder rexBuilder, RexNode expr, RelDataType fieldType) {
+    RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
+    RelDataType varcharType =
+        typeFactory.createTypeWithNullability(
+            typeFactory.createSqlType(SqlTypeName.VARCHAR), fieldType.isNullable());
+    return rexBuilder.makeCast(varcharType, expr);
+  }
+}
diff --git a/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java
new
file mode 100644
index 00000000000..b15d830d412
--- /dev/null
+++ b/api/src/main/java/org/opensearch/sql/api/spec/datetime/DatetimeUdtNormalizeRule.java
@@ -0,0 +1,73 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import java.util.Optional;
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.opensearch.sql.api.spec.datetime.DatetimeExtension.UdtMapping;
+
+/**
+ * Temporary patch that rewrites datetime UDT return types on RexCall nodes to standard Calcite
+ * types.
+ *
+ * <p>The shared {@link #INSTANCE} never walks a plan itself: {@code RelShuttleImpl} tracks parents
+ * in a mutable stack, so every top-level visit delegates to a fresh {@link Traversal}.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+class DatetimeUdtNormalizeRule extends RelHomogeneousShuttle {
+
+  static final DatetimeUdtNormalizeRule INSTANCE = new DatetimeUdtNormalizeRule();
+
+  /** Normalizes the plan rooted at {@code other} using a traversal private to this call. */
+  @Override
+  public RelNode visit(RelNode other) {
+    return other.accept(new Traversal());
+  }
+
+  /** Single-use shuttle that visits inputs first, then rewrites the node's own expressions. */
+  private static final class Traversal extends RelHomogeneousShuttle {
+
+    @Override
+    public RelNode visit(RelNode other) {
+      RelNode visited = super.visit(other);
+      RexBuilder rexBuilder = visited.getCluster().getRexBuilder();
+      RelDataTypeFactory typeFactory = rexBuilder.getTypeFactory();
+      return visited.accept(
+          new RexShuttle() {
+            @Override
+            public RexNode visitCall(RexCall call) {
+              call = (RexCall) super.visitCall(call);
+              Optional<UdtMapping> mapping = UdtMapping.fromUdtType(call.getType());
+              if (mapping.isEmpty()) {
+                return call;
+              }
+
+              // Normalize UDT return type to standard Calcite DATE/TIME/TIMESTAMP
+              UdtMapping m = mapping.get();
+              SqlTypeName stdTypeName = m.getStdType();
+              RelDataType baseType =
+                  stdTypeName.allowsPrec()
+                      ? typeFactory.createSqlType(
+                          stdTypeName, typeFactory.getTypeSystem().getMaxPrecision(stdTypeName))
+                      : typeFactory.createSqlType(stdTypeName);
+              RelDataType stdType =
+                  typeFactory.createTypeWithNullability(baseType, call.getType().isNullable());
+              return call.clone(stdType, call.getOperands());
+            }
+          });
+    }
+  }
+}
diff --git a/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java
new file mode 100644
index 00000000000..fc089150109
--- /dev/null
+++ b/api/src/test/java/org/opensearch/sql/api/spec/datetime/DatetimeExtensionTest.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.api.spec.datetime;
+
+import static org.apache.calcite.sql.type.SqlTypeName.BIGINT;
+import static org.apache.calcite.sql.type.SqlTypeName.DATE;
+import static org.apache.calcite.sql.type.SqlTypeName.INTEGER;
+import static org.apache.calcite.sql.type.SqlTypeName.TIME;
+import static org.apache.calcite.sql.type.SqlTypeName.TIMESTAMP;
+import static org.apache.calcite.sql.type.SqlTypeName.VARCHAR;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.calcite.rel.RelHomogeneousShuttle;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.schema.Table;
+import org.apache.calcite.schema.impl.AbstractSchema;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.Before;
+import org.junit.Test;
+import org.opensearch.sql.api.ResultSetAssertion;
+import org.opensearch.sql.api.UnifiedQueryContext;
+import org.opensearch.sql.api.UnifiedQueryTestBase;
+import
org.opensearch.sql.api.compiler.UnifiedQueryCompiler;
+import org.opensearch.sql.executor.QueryType;
+
+public class DatetimeExtensionTest extends UnifiedQueryTestBase implements ResultSetAssertion {
+
+  private UnifiedQueryCompiler compiler;
+
+  @Override
+  protected UnifiedQueryContext.Builder contextBuilder() {
+    return UnifiedQueryContext.builder()
+        .language(QueryType.PPL)
+        .catalog(
+            DEFAULT_CATALOG,
+            new AbstractSchema() {
+              @Override
+              protected Map<String, Table> getTableMap() {
+                return Map.of("events", createEventsTable());
+              }
+            });
+  }
+
+  @Before
+  public void setUp() {
+    super.setUp();
+    compiler = new UnifiedQueryCompiler(context);
+  }
+
+  private Table createEventsTable() {
+    return SimpleTable.builder()
+        .col("id", INTEGER)
+        .col("name", VARCHAR)
+        .col("hire_date", DATE)
+        .col("start_time", TIME)
+        .col("created_at", TIMESTAMP)
+        .row(new Object[] {1, "Alice", 19738, 43200000, 1705305600000L})
+        .row(new Object[] {2, "Bob", 19894, 50400000, 1718841600000L})
+        .build();
+  }
+
+  @Test
+  public void testUdfResultNormalizedAndCastToVarchar() {
+    var plan =
+        givenQuery(
+                """
+                source = catalog.events \
+                | eval d = DATE(name), t = TIME(name), ts = TIMESTAMP(name) \
+                | fields d, t, ts\
+                """)
+            .assertPlan(
+                """
+                LogicalProject(d=[CAST($0):VARCHAR], t=[CAST($1):VARCHAR], ts=[CAST($2):VARCHAR])
+                  LogicalProject(d=[DATE($1)], t=[TIME($1)], ts=[TIMESTAMP($1)])
+                    LogicalTableScan(table=[[catalog, events]])
+                """)
+            .plan();
+    assertCallType(plan, "DATE", DATE);
+    assertCallType(plan, "TIME", TIME, 9);
+    assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9);
+  }
+
+  @Test
+  public void testNestedUdfCallsNormalized() {
+    var plan =
+        givenQuery("source = catalog.events | eval d = DATEDIFF(DATE(name), DATE(name)) | fields d")
+            .assertPlan(
+                """
+                LogicalProject(d=[DATEDIFF(DATE($1), DATE($1))])
+                  LogicalTableScan(table=[[catalog, events]])
+                """)
+            .plan();
+    assertCallType(plan, "DATE", DATE);
+    assertCallType(plan, "DATEDIFF", BIGINT);
+  }
+
+  @Test
+  public void testDateLiteralCastToVarchar() {
+    var plan =
+        givenQuery("source = catalog.events | eval d = DATE('2024-01-01') | fields d")
+            .assertPlan(
+                """
+                LogicalProject(d=[CAST($0):VARCHAR])
+                  LogicalProject(d=[DATE('2024-01-01':VARCHAR)])
+                    LogicalTableScan(table=[[catalog, events]])
+                """)
+            .plan();
+    assertCallType(plan, "DATE", DATE);
+  }
+
+  @Test
+  public void testFilterWithTimestampLiteral() {
+    var plan =
+        givenQuery(
+                """
+                source = catalog.events | where created_at > "2024-01-01T00:00:00Z" | fields id\
+                """)
+            .assertPlan(
+                """
+                LogicalProject(id=[$0])
+                  LogicalFilter(condition=[>($4, TIMESTAMP('2024-01-01T00:00:00Z':VARCHAR))])
+                    LogicalTableScan(table=[[catalog, events]])
+                """)
+            .plan();
+    assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9);
+  }
+
+  @Test
+  public void testComparisonWithDatetimeUdf() {
+    var plan =
+        givenQuery("source = catalog.events | where created_at < DATE(name) | fields id")
+            .assertPlan(
+                """
+                LogicalProject(id=[$0])
+                  LogicalFilter(condition=[<($4, TIMESTAMP(DATE($1)))])
+                    LogicalTableScan(table=[[catalog, events]])
+                """)
+            .plan();
+    assertCallType(plan, "DATE", DATE);
+    assertCallType(plan, "TIMESTAMP", TIMESTAMP, 9);
+  }
+
+  @Test
+  public void testAllStandardDatetimeTypesCastToVarchar() {
+    givenQuery("source = catalog.events | fields hire_date, start_time, created_at")
+        .assertPlan(
+            """
+            LogicalProject(hire_date=[CAST($0):VARCHAR NOT NULL], start_time=[CAST($1):VARCHAR NOT NULL], created_at=[CAST($2):VARCHAR NOT NULL])
+              LogicalProject(hire_date=[$2], start_time=[$3], created_at=[$4])
+                LogicalTableScan(table=[[catalog, events]])
+            """);
+  }
+
+  @Test
+  public void testNonDatetimeFieldsNotWrapped() {
+    givenQuery("source = catalog.events | fields id, name")
+        .assertPlan(
+            """
+            LogicalProject(id=[$0], name=[$1])
+              LogicalTableScan(table=[[catalog, events]])
+            """);
+  }
+
+  @Test
+  public void testOutputCastCanCompileAndExecute() throws Exception {
+    RelNode plan =
+        planner.plan("source = catalog.events | fields hire_date, start_time, created_at");
+    try (PreparedStatement statement = compiler.compile(plan)) {
+      ResultSet resultSet = statement.executeQuery();
+      verify(resultSet)
+          .expectSchema(
+              col("hire_date", java.sql.Types.VARCHAR),
+              col("start_time", java.sql.Types.VARCHAR),
+              col("created_at", java.sql.Types.VARCHAR))
+          .expectData(
+              row("2024-01-16", "12:00:00", "2024-01-15 08:00:00"),
+              row("2024-06-20", "14:00:00", "2024-06-20 00:00:00"));
+    }
+  }
+
+  private static void assertCallType(RelNode plan, String operatorName, SqlTypeName expectedType) {
+    assertCallType(plan, operatorName, expectedType, -1);
+  }
+
+  private static void assertCallType(
+      RelNode plan, String operatorName, SqlTypeName expectedType, int expectedPrecision) {
+    AtomicReference<RexCall> ref = new AtomicReference<>();
+    plan.accept(
+        new RelHomogeneousShuttle() {
+          @Override
+          public RelNode visit(RelNode other) {
+            RelNode visited = super.visit(other);
+            visited.accept(
+                new RexShuttle() {
+                  @Override
+                  public RexNode visitCall(RexCall call) {
+                    if (ref.get() == null
+                        && call.getOperator().getName().equalsIgnoreCase(operatorName)) {
+                      ref.set(call);
+                    }
+                    return super.visitCall(call);
+                  }
+                });
+            return visited;
+          }
+        });
+    assertNotNull("No RexCall found for: " + operatorName, ref.get());
+    assertEquals(operatorName + " type", expectedType, ref.get().getType().getSqlTypeName());
+    if (expectedPrecision >= 0) {
+      assertEquals(
+          operatorName + " precision", expectedPrecision, ref.get().getType().getPrecision());
+    }
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/executor/OpenSearchTypeSystem.java b/core/src/main/java/org/opensearch/sql/executor/OpenSearchTypeSystem.java
index b84d7dcf4d6..941f42de46c 100644
--- a/core/src/main/java/org/opensearch/sql/executor/OpenSearchTypeSystem.java
+++ b/core/src/main/java/org/opensearch/sql/executor/OpenSearchTypeSystem.java
@@ -22,6 +22,9 @@ public class OpenSearchTypeSystem extends RelDataTypeSystemImpl {
   // same with Spark
DecimalType.MAX_SCALE
   public static int MAX_SCALE = 38;
 
+  /** Maximum fractional seconds precision for TIME and TIMESTAMP types (nanosecond). */
+  public static final int MAX_DATETIME_PRECISION = 9;
+
   private OpenSearchTypeSystem() {}
 
   @Override
@@ -29,6 +32,24 @@ public int getMaxNumericPrecision() {
     return MAX_PRECISION;
   }
 
+  /**
+   * Returns {@link #MAX_DATETIME_PRECISION} for the TIME and TIMESTAMP families, including their
+   * time-zone variants, and defers to the superclass for every other type.
+   */
+  @Override
+  public int getMaxPrecision(SqlTypeName typeName) {
+    return switch (typeName) {
+      case TIME,
+              TIME_WITH_LOCAL_TIME_ZONE,
+              TIME_TZ,
+              TIMESTAMP,
+              TIMESTAMP_WITH_LOCAL_TIME_ZONE,
+              TIMESTAMP_TZ ->
+          MAX_DATETIME_PRECISION;
+      default -> super.getMaxPrecision(typeName);
+    };
+  }
+
   @Override
   public int getMaxNumericScale() {
     return MAX_SCALE;