diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterPredicateImpl.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterPredicateImpl.java index 17409cb3e..745a3453a 100644 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterPredicateImpl.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterPredicateImpl.java @@ -56,8 +56,7 @@ public SerializableFunction, Object> deriveOperation(final SqlKind case OR -> input.stream().anyMatch(obj -> Boolean.class.cast(obj).booleanValue()); case MINUS -> widenToDouble.apply(input.get(0)) - widenToDouble.apply(input.get(1)); case PLUS -> widenToDouble.apply(input.get(0)) + widenToDouble.apply(input.get(1)); - // TODO: may need better support for CASTing in the future. See sqlCast() in this file. - case CAST -> input.get(0) instanceof Number ? widenToDouble.apply(input.get(0)) : ensureComparable.apply(input.get(0)); + case CAST -> SqlRuntimeCast.castValue(input.get(0), returnType); case SEARCH -> { if (input.get(0) instanceof final ImmutableRangeSet range) { assert input.get(1) instanceof Comparable @@ -84,16 +83,6 @@ public SerializableFunction, Object> deriveOperation(final SqlKind }; } - /** - * Java implementation of SQL cast. - * @param input input field - * @param type the new return type of the field - * @return Java-type equivalent to {@link SqlTypeName} counterpart. - */ - private static Object sqlCast(Object input, SqlTypeName type){ - throw new UnsupportedOperationException("sqlCasting is not yet implemented."); - } - /** * Java equivalent of SQL like clauses * diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCast.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCast.java new file mode 100644 index 000000000..667e13404 --- /dev/null +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCast.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.api.sql.calcite.converter.functions; + +import java.math.BigDecimal; +import java.util.Calendar; +import java.util.Date; + +import org.apache.calcite.runtime.SqlFunctions; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.DateString; +import org.apache.calcite.util.NlsString; + +/** + * Runtime SQL {@code CAST} for Wayang Java filter evaluation, delegating to + * {@link SqlFunctions} where possible. + */ +public final class SqlRuntimeCast { + + private SqlRuntimeCast() {} + + /** + * @param input evaluated operand (SQL NULL is {@code null}) + * @param target destination SQL type name of the cast (from the RexCall result type) + * @return value suitable for comparisons and filter logic + */ + public static Object castValue(final Object input, final SqlTypeName target) { + if (input == null) { + return null; + } + final Object v = unwrapForCast(input); + switch (target) { + case BOOLEAN: + return SqlFunctions.toBoolean(v); + case TINYINT: + return SqlFunctions.toByte(v); + case SMALLINT: + return SqlFunctions.toShort(v); + case INTEGER: + return SqlFunctions.toInt(v); + case BIGINT: + return SqlFunctions.toLong(v); + case DECIMAL: + return SqlFunctions.toBigDecimal(v); + case FLOAT: + case REAL: + return castToFloat(v); + case DOUBLE: + return castToDouble(v); + case CHAR: + case VARCHAR: + return castToString(v); + default: + throw new UnsupportedOperationException( + "CAST to " + target + " is not supported in Java filter evaluation yet."); + } + } + + private static Object unwrapForCast(final Object o) { + if (o instanceof NlsString) { + return ((NlsString) o).getValue(); + } + if (o instanceof Character) { + return o.toString(); + } + return o; + } + + private static float castToFloat(final Object v) { + if (v instanceof DateString) { + return (float) ((DateString) v).getMillisSinceEpoch(); + } + if (v instanceof Date) { + return (float) ((Date) v).getTime(); + } + if (v instanceof Calendar) { + return (float) ((Calendar) v).getTimeInMillis(); + } + return SqlFunctions.toFloat(v); + } + + private static double castToDouble(final Object v) { + if (v instanceof DateString) { + return (double) ((DateString) v).getMillisSinceEpoch(); + } + if (v instanceof Date) { + return (double) ((Date) v).getTime(); + } + if (v instanceof Calendar) { + return (double) ((Calendar) v).getTimeInMillis(); + } + return SqlFunctions.toDouble(v); + } + + private static String castToString(final Object v) { + if (v instanceof String) { + return (String) v; + } + if (v instanceof NlsString) { + return ((NlsString) v).getValue(); + } + if (v instanceof Boolean) { + return SqlFunctions.toString((Boolean) v); + } + if (v instanceof Float) { + return SqlFunctions.toString((Float) v); + } + if (v instanceof Double) { + return SqlFunctions.toString((Double) v); + } + if (v instanceof BigDecimal) { + return SqlFunctions.toString((BigDecimal) v); + } + if (v instanceof Number) { + return ((Number) v).toString(); + } + if (v instanceof DateString) { + return v.toString(); + } + if (v instanceof Character) { + return v.toString(); + } + if (v instanceof Date) { + return v.toString(); + } + if (v instanceof Calendar) { + return ((Calendar) v).getTime().toString(); + } + return String.valueOf(v); + } +} diff --git a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java index 88114774b..e5adfc6d6 100755 --- a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java +++ b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java @@ -172,6 +172,23 @@ void javaFilterWithCast() throws Exception { assertTrue(result.stream().allMatch(field -> field.getField(1).equals("test1"))); } + @Test + void javaFilterWithCastIntColumnToVarchar() throws Exception { + final SqlContext sqlContext = this.createSqlContext("/data/exampleInt.csv"); + final Tuple2, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext, + "SELECT * FROM fs.exampleInt WHERE CAST(NAMEB AS VARCHAR) = '1'"); + final Collection result = t.field0; + final WayangPlan wayangPlan = t.field1; + + PlanTraversal.upstream().traverse(wayangPlan.getSinks()).getTraversedNodes() + .forEach(node -> node.addTargetPlatform(Java.platform())); + + sqlContext.execute(wayangPlan); + + assertTrue(!result.isEmpty()); + assertTrue(result.stream().allMatch(field -> field.getField(1).equals(1))); + } + @Test void sqlApiSourceTest() throws Exception { final JavaTypeFactoryImpl typeFactory = new JavaTypeFactoryImpl(); diff --git a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCastTest.java b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCastTest.java new file mode 100644 index 000000000..0b531f574 --- /dev/null +++ b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/calcite/converter/functions/SqlRuntimeCastTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.api.sql.calcite.converter.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.math.BigDecimal; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.util.NlsString; +import org.junit.jupiter.api.Test; + +class SqlRuntimeCastTest { + + @Test + void castNullYieldsNull() { + assertNull(SqlRuntimeCast.castValue(null, SqlTypeName.INTEGER)); + } + + @Test + void castIntegerToVarchar() { + assertEquals("1", SqlRuntimeCast.castValue(1, SqlTypeName.VARCHAR)); + } + + @Test + void castStringToInteger() { + assertEquals(42, SqlRuntimeCast.castValue("42", SqlTypeName.INTEGER)); + } + + @Test + void castStringToDouble() { + assertEquals(1.5d, (Double) SqlRuntimeCast.castValue("1.5", SqlTypeName.DOUBLE), 1e-9); + } + + @Test + void castNlsStringToInteger() { + final NlsString nls = new NlsString("7", "UTF-8", null); + assertEquals(7, SqlRuntimeCast.castValue(nls, SqlTypeName.INTEGER)); + } + + @Test + void castStringToBoolean() { + assertTrue(SqlRuntimeCast.castValue("TRUE", SqlTypeName.BOOLEAN) instanceof Boolean); + assertEquals(true, SqlRuntimeCast.castValue("TRUE", SqlTypeName.BOOLEAN)); + } + + @Test + void castInvalidBooleanThrows() { + assertThrows(RuntimeException.class, () -> SqlRuntimeCast.castValue("maybe", SqlTypeName.BOOLEAN)); + } + + @Test + void castBigDecimalToVarcharUsesSqlFormat() { + final String s = SqlRuntimeCast.castValue(BigDecimal.valueOf(1, 1), SqlTypeName.VARCHAR).toString(); + assertTrue(s.contains("1")); + } + + @Test + void castToDateUnsupported() { + assertThrows(UnsupportedOperationException.class, + () -> SqlRuntimeCast.castValue("2020-01-01", SqlTypeName.DATE)); + } +}