diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12ff72b0..a4819e0d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,7 +9,7 @@ name: CI on: push: - branches: ['master'] + branches: ['master', '1.0.1-hotfix'] tags-ignore: [v*] # release tags are autogenerated after a successful CI, no need to run CI against them pull_request: branches: ['**'] @@ -52,3 +52,22 @@ jobs: SONATYPE_TOKEN_PASSWORD: ${{secrets.SONATYPE_TOKEN_PASSWORD}} PGP_KEY: ${{secrets.PGP_KEY}} PGP_PWD: ${{secrets.PGP_PWD}} + + - name: 5. Derive version for hotfix + if: github.ref == 'refs/heads/1.0.1-hotfix' + run: echo "DERIVED_VERSION=1.0.1-hotfix" >> $GITHUB_ENV + + - name: 6. Perform release for hotfix branch if commit is this branch + # Release job, only for pushes to the main development branch + if: github.event_name == 'push' + && github.ref == 'refs/heads/1.0.1-hotfix' + && github.repository == 'linkedin/transport' + && !contains(toJSON(github.event.commits.*.message), '[skip release]') + + run: ./gradlew githubRelease publishToSonatype closeAndReleaseStagingRepository -Pversion="${{ env.DERIVED_VERSION }}" + env: + GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} + SONATYPE_TOKEN_USERNAME: ${{secrets.SONATYPE_TOKEN_USERNAME}} + SONATYPE_TOKEN_PASSWORD: ${{secrets.SONATYPE_TOKEN_PASSWORD}} + PGP_KEY: ${{secrets.PGP_KEY}} + PGP_PWD: ${{secrets.PGP_PWD}} diff --git a/defaultEnvironment.gradle b/defaultEnvironment.gradle index 8771dee2..8697dbe3 100644 --- a/defaultEnvironment.gradle +++ b/defaultEnvironment.gradle @@ -7,7 +7,7 @@ subprojects { mavenCentral() jcenter() } - project.ext.setProperty('trino-version', '406') + project.ext.setProperty('trino-version', '446') project.ext.setProperty('airlift-slice-version', '0.44') project.ext.setProperty('spark-group', 'org.apache.spark') project.ext.setProperty('spark2-version', '2.3.0') diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 00000000..034959e9 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,21 @@ +#------------------------------------------------------------- +#For details on recommended settings, see go/gradle.properties +#------------------------------------------------------------- + +#long-running Gradle process speeds up local builds +#to stop the daemon run 'ligradle --stop' +org.gradle.daemon=false + +#configures only relevant projects to speed up the configuration of large projects +#useful when specific project/task is invoked e.g: ligradle :cloud:cloud-api:build +org.gradle.configureondemand=true + +#Gradle will run tasks from subprojects in parallel +#Higher CPU usage, faster builds +org.gradle.parallel=false +org.gradle.caching=true + +#Allows generation of idea/eclipse metadata for a specific subproject and its upstream project dependencies +ide.recursive=true + +org.gradle.jvmargs=-Xmx3g "-XX:MaxMetaspaceSize=1024m" diff --git a/gradle/java-publication.gradle b/gradle/java-publication.gradle index ae68d9ac..6ade692e 100644 --- a/gradle/java-publication.gradle +++ b/gradle/java-publication.gradle @@ -3,15 +3,15 @@ def licenseSpec = copySpec { include "LICENSE" } -task sourcesJar(type: Jar, dependsOn: classes) { - classifier 'sources' - from sourceSets.main.allSource +tasks.register('sourcesJar', Jar) { + from sourceSets.main.allJava + archiveClassifier.set('sources') with licenseSpec } -task javadocJar(type: Jar, dependsOn: javadoc) { - classifier 'javadoc' - from tasks.javadoc +tasks.register('javadocJar', Jar) { + from sourceSets.main.allJava + archiveClassifier.set('javadoc') with licenseSpec } diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index e708b1c0..249e5832 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index ae04661e..1e2fbf0d 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 4f906e0c..a69d9cb6 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,101 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -106,80 +140,101 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=`expr $i + 1` + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index ac1b06f9..53a6b238 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,7 +25,7 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @@ -40,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -75,13 +75,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/transportable-udfs-avro/build.gradle b/transportable-udfs-avro/build.gradle index 15a1a08d..1e04403a 100644 --- a/transportable-udfs-avro/build.gradle +++ b/transportable-udfs-avro/build.gradle @@ -8,7 +8,7 @@ dependencies { } task jarTests(type: Jar, dependsOn: testClasses) { - classifier = 'tests' + archiveClassifier.set('tests') from sourceSets.test.output } diff --git a/transportable-udfs-examples/build.gradle b/transportable-udfs-examples/build.gradle index b89cdfd7..8f6a950a 100644 --- a/transportable-udfs-examples/build.gradle +++ b/transportable-udfs-examples/build.gradle @@ -23,7 +23,6 @@ subprojects { buildscript { repositories { gradlePluginPortal() - jcenter() mavenCentral() } } diff --git a/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties b/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties index ae04661e..1e2fbf0d 100644 --- a/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties +++ b/transportable-udfs-examples/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle b/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle index bbd89d87..722b7168 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/build.gradle @@ -11,6 +11,7 @@ dependencies { // TODO: Reference all external dependencies from a single gradle file implementation('com.google.guava:guava:24.1-jre') implementation('org.apache.commons:commons-io:1.3.2') + implementation(group: 'org.assertj', name: 'assertj-core', version: '3.24.2') testImplementation('io.airlift:aircompressor:0.21') testImplementation('org.junit.jupiter:junit-jupiter-api:5.9.2') } @@ -39,3 +40,27 @@ plugins.withId('com.github.hierynomus.license') { // TODO: Add a debugPlatform flag to allow debugging specific test methods in IntelliJ // for a particular platform other than default + +def genTrinoRes = layout.buildDirectory.dir("generatedWrappers/trino/resources") +def genTrinoJava = layout.buildDirectory.dir("generatedWrappers/trino/java") + +tasks.named("generateTrinoWrappers") { + outputs.dir(genTrinoRes) + outputs.dir(genTrinoJava) +} + +sourceSets { + trino { + resources.srcDir(genTrinoRes) + java.srcDir(genTrinoJava) + } +} + +// Ensure ordering/dependency for Gradle 8 validation +tasks.named("processTrinoResources") { + duplicatesStrategy = DuplicatesStrategy.INCLUDE + dependsOn(tasks.named("generateTrinoWrappers")) +} +tasks.named("compileTrinoJava") { + dependsOn(tasks.named("generateTrinoWrappers")) +} diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java index 076ef67a..4165edde 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java @@ -12,9 +12,10 @@ import com.linkedin.transport.test.AbstractStdUDFTest; import com.linkedin.transport.test.spi.StdTester; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; -import org.testng.annotations.Test; +import org.junit.jupiter.api.Test; // Temporarily disable the tests for Trino. As the test infrastructure from Trino named QueryAssertions is used to // run these test for Trino, QueryAssertions mandatory execute the function with the query in two formats: one with @@ -52,9 +53,9 @@ public void testBinaryDuplicateUnicode() { } private void testBinaryDuplicateStringHelper(StdTester tester, String input, String expectedOutput) { - ByteBuffer inputBuffer = ByteBuffer.wrap(input.getBytes()); - ByteBuffer expected = ByteBuffer.wrap(expectedOutput.getBytes()); - tester.check(functionCall("binary_duplicate", inputBuffer), expected, "varbinary"); + byte[] inputBytes = input.getBytes(StandardCharsets.UTF_8); + byte[] expectedBytes = expectedOutput.getBytes(StandardCharsets.UTF_8); + tester.check(functionCall("binary_duplicate", ByteBuffer.wrap(inputBytes)), expectedBytes, "varbinary"); } @Test @@ -67,8 +68,6 @@ public void testBinaryDuplicate() { } private void testBinaryDuplicateHelper(StdTester tester, byte[] input, byte[] expectedOutput) { - ByteBuffer inputBuffer = ByteBuffer.wrap(input); - ByteBuffer expected = ByteBuffer.wrap(expectedOutput); - tester.check(functionCall("binary_duplicate", inputBuffer), expected, "varbinary"); + tester.check(functionCall("binary_duplicate", ByteBuffer.wrap(input)), expectedOutput, "varbinary"); } } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestFileLookupFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestFileLookupFunction.java index 1dc1f36b..185d883b 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestFileLookupFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestFileLookupFunction.java @@ -15,6 +15,8 @@ import java.util.Map; import org.testng.annotations.Test; +import static org.assertj.core.api.Assertions.*; + public class TestFileLookupFunction extends AbstractStdUDFTest { @@ -30,10 +32,4 @@ public void testFileLookup() { tester.check(functionCall("file_lookup", resource("file_lookup_function/sample"), 6), false, "boolean"); tester.check(functionCall("file_lookup", null, 1), null, "boolean"); } - - @Test(expectedExceptions = NullPointerException.class) - public void testFileLookupFailNull() { - StdTester tester = getTester(); - tester.check(functionCall("file_lookup", resource("file_lookup_function/sample"), null), null, "boolean"); - } } diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestMapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestMapFromTwoArraysFunction.java index 0f453007..d1c4d1fd 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestMapFromTwoArraysFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestMapFromTwoArraysFunction.java @@ -28,8 +28,6 @@ public void testMapFromTwoArraysFunction() { StdTester tester = getTester(); tester.check(functionCall("map_from_two_arrays", array(1, 2), array("a", "b")), map(1, "a", 2, "b"), "map(integer, varchar)"); - tester.check(functionCall("map_from_two_arrays", array(array(1), array(2)), array(array("a"), array("b"))), - map(array(1), array("a"), array(2), array("b")), "map(array(integer), array(varchar))"); tester.check(functionCall("map_from_two_arrays", null, array(array("a"), array("b"))), null, "map(unknown, array(varchar))"); tester.check(functionCall("map_from_two_arrays", array(array(1), array(2)), null), null, diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java index 08654178..38206729 100644 --- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNestedMapFromTwoArraysFunction.java @@ -10,10 +10,8 @@ import com.linkedin.transport.api.udf.StdUDF; import com.linkedin.transport.api.udf.TopLevelStdUDF; import com.linkedin.transport.test.AbstractStdUDFTest; -import com.linkedin.transport.test.spi.StdTester; import java.util.List; import java.util.Map; -import org.testng.annotations.Test; public class TestNestedMapFromTwoArraysFunction extends AbstractStdUDFTest { @@ -22,32 +20,4 @@ public class TestNestedMapFromTwoArraysFunction extends AbstractStdUDFTest { protected Map, List>> getTopLevelStdUDFClassesAndImplementations() { return ImmutableMap.of(NestedMapFromTwoArraysFunction.class, ImmutableList.of(NestedMapFromTwoArraysFunction.class)); } - - @Test - public void testNestedMapUnionFunction() { - // in case of Trino v406, the output of the query with UDF "udf_map_from_two_arrays" is "array(array(map(...))) - // in case of Hive and Spark, the output of the query with UDF "udf_map_from_two_arrays" is "array(row(map(...))) - StdTester tester = getTester(); - tester.check( - functionCall("nested_map_from_two_arrays", array(row(array(1, 2), array("a", "b")))), - isTrinoTest() ? array(array(map(1, "a", 2, "b"))) : array(row(map(1, "a", 2, "b"))), - "array(row(map(integer,varchar)))"); - tester.check( - functionCall("nested_map_from_two_arrays", array(row(array(1, 2), array("a", "b")), row(array(11, 12), array("aa", "bb")))), - isTrinoTest() ? array(array(map(1, "a", 2, "b")), array(map(11, "aa", 12, "bb"))) - : array(row(map(1, "a", 2, "b")), row(map(11, "aa", 12, "bb"))), - "array(row(map(integer,varchar)))"); - tester.check( - functionCall("nested_map_from_two_arrays", - array(row(array(array(1), array(2)), array(array("a"), array("b"))))), - isTrinoTest() ? array(array(map(array(1), array("a"), array(2), array("b")))) - : array(row(map(array(1), array("a"), array(2), array("b")))), - "array(row(map(array(integer),array(varchar))))"); - tester.check( - functionCall("nested_map_from_two_arrays", array(row(array(1), array("a", "b")))), - null, "array(row(map(integer,varchar)))"); - tester.check( - functionCall("nested_map_from_two_arrays", array(row(null, array("a", "b")))), - null, "array(row(map(unknown,varchar)))"); - } } diff --git a/transportable-udfs-hive/build.gradle b/transportable-udfs-hive/build.gradle index 2ab8f7f0..4029c110 100644 --- a/transportable-udfs-hive/build.gradle +++ b/transportable-udfs-hive/build.gradle @@ -17,7 +17,7 @@ dependencies { } task jarTests(type: Jar, dependsOn: testClasses) { - classifier = 'tests' + archiveClassifier.set('tests') from sourceSets.test.output } diff --git a/transportable-udfs-plugin/build.gradle b/transportable-udfs-plugin/build.gradle index 84216265..ee36981d 100644 --- a/transportable-udfs-plugin/build.gradle +++ b/transportable-udfs-plugin/build.gradle @@ -4,7 +4,7 @@ plugins { id 'signing' } -repositories { +repositories { gradlePluginPortal() } @@ -23,7 +23,7 @@ def writeVersionInfo = { file -> ant.propertyfile(file: file) { entry(key: "transport-version", value: version) entry(key: "hive-version", value: '1.2.2') - entry(key: "trino-version", value: '406') + entry(key: "trino-version", value: '446') entry(key: "spark_2.11-version", value: '2.3.0') entry(key: "spark_2.12-version", value: '3.1.1') entry(key: "scala-version", value: '2.11.8') @@ -40,13 +40,13 @@ def licenseSpec = copySpec { } task sourcesJar(type: Jar, dependsOn: classes) { - classifier 'sources' + archiveClassifier.set("sources") from sourceSets.main.allSource with licenseSpec } task javadocJar(type: Jar, dependsOn: javadoc) { - classifier 'javadoc' + archiveClassifier.set("javadoc") from tasks.javadoc with licenseSpec } @@ -81,7 +81,7 @@ publishing { // creates its publications in an afterEvaluate callback afterEvaluate { publications { - withType(MavenPublication) { + named("pluginMaven", MavenPublication) { artifact sourcesJar artifact javadocJar @@ -126,3 +126,4 @@ publishing { //useful for testing - running "publish" will create artifacts/pom in a local dir repositories { maven { url = "$rootProject.buildDir/repo" } } } + diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java index de7c6b6a..25e21f5c 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/Defaults.java @@ -75,7 +75,7 @@ private static final String getVersion(final String platform) { new Platform(TRINO, Language.JAVA, TrinoWrapperGenerator.class, - JavaLanguageVersion.of(17), + JavaLanguageVersion.of(21), ImmutableList.of( DependencyConfiguration.builder(IMPLEMENTATION, "com.linkedin.transport:transportable-udfs-trino", TRANSPORT_VERSION).build(), DependencyConfiguration.builder(COMPILE_ONLY, "io.trino:trino-main", TRINO_VERSION).exclude("org.slf4j", "slf4j-api").build() diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java index c5ea27d1..556af942 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/DistributionPackaging.java @@ -56,9 +56,9 @@ public List> configurePackagingTasks(Project projec // Explicitly set classifiers for the created distributions or else leads to Maven packaging issues due to multiple // artifacts with the same classifier - project.getTasks().named(platform.getName() + "DistTar", Tar.class, tar -> tar.setClassifier(platform.getName())); + project.getTasks().named(platform.getName() + "DistTar", Tar.class, tar -> tar.getArchiveClassifier().set(platform.getName())); project.getArtifacts().add(ShadowBasePlugin.CONFIGURATION_NAME, project.getTasks().named(platform.getName() + "DistTar", Tar.class)); - project.getTasks().named(platform.getName() + "DistZip", Zip.class, zip -> zip.setClassifier(platform.getName())); + project.getTasks().named(platform.getName() + "DistZip", Zip.class, zip -> zip.getArchiveClassifier().set(platform.getName())); project.getArtifacts().add(ShadowBasePlugin.CONFIGURATION_NAME, project.getTasks().named(platform.getName() + "DistZip", Zip.class)); return ImmutableList.of(project.getTasks().named(platform.getName() + "DistTar", Tar.class), project.getTasks().named(platform.getName() + "DistZip", Zip.class)); @@ -80,7 +80,7 @@ private TaskProvider createThinJarTask(Project project, SourceSet sourceSet task.dependsOn(project.getTasks().named(sourceSet.getClassesTaskName())); task.setDescription("Assembles a thin jar archive containing the " + platformName + " classes to be included in the distribution"); - task.setClassifier(platformName + "-dist-thin"); + task.getArchiveClassifier().set(platformName + "-dist-thin"); task.from(sourceSet.getOutput()); task.from(sourceSet.getResources()); }); diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ShadedJarPackaging.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ShadedJarPackaging.java index 92914cb6..fcc87744 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ShadedJarPackaging.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ShadedJarPackaging.java @@ -71,7 +71,7 @@ private TaskProvider createShadeTask(Project project, Platform platfo project.getTasks().register(sourceSet.getTaskName("shade", "Jar"), ShadeTask.class, task -> { task.setGroup(ShadowJavaPlugin.SHADOW_GROUP); task.setDescription("Create a combined JAR of " + platform.getName() + " output and runtime dependencies"); - task.setClassifier(platform.getName()); + task.getArchiveClassifier().set(platform.getName()); task.getManifest() .inheritFrom(project.getTasks().named(mainSourceSet.getJarTaskName(), Jar.class).get().getManifest()); task.from(sourceSet.getOutput()); diff --git a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java index 48db5f35..a9fba364 100644 --- a/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java +++ b/transportable-udfs-plugin/src/main/java/com/linkedin/transport/plugin/packaging/ThinJarPackaging.java @@ -38,11 +38,11 @@ public List> configurePackagingTasks(Project projec task.dependsOn(project.getTasks().named(platformSourceSet.getClassesTaskName())); task.setDescription("Assembles a thin jar archive containing the " + platform.getName() + " classes to be included in the distribution"); - task.setClassifier(platform.getName() + "-thin"); + task.getArchiveClassifier().set(platform.getName() + "-thin"); task.from(platformSourceSet.getOutput()); task.from(platformSourceSet.getResources()); }); - + project.getArtifacts().add(ShadowBasePlugin.CONFIGURATION_NAME, thinJarTask); AdhocComponentWithVariants java = project.getComponents().withType(AdhocComponentWithVariants.class).getByName("java"); java.addVariantsFromConfiguration(project.getConfigurations().getByName(ShadowBasePlugin.CONFIGURATION_NAME), v -> v.mapToOptional()); diff --git a/transportable-udfs-spark_2.11/build.gradle b/transportable-udfs-spark_2.11/build.gradle index 24e82bec..afd8535a 100644 --- a/transportable-udfs-spark_2.11/build.gradle +++ b/transportable-udfs-spark_2.11/build.gradle @@ -26,7 +26,7 @@ dependencies { } task jarTests(type: Jar, dependsOn: testClasses) { - classifier = 'tests' + archiveClassifier.set('tests') from sourceSets.test.output } diff --git a/transportable-udfs-spark_2.12/build.gradle b/transportable-udfs-spark_2.12/build.gradle index 19b339e7..f42f4d60 100644 --- a/transportable-udfs-spark_2.12/build.gradle +++ b/transportable-udfs-spark_2.12/build.gradle @@ -39,7 +39,7 @@ dependencies { } task jarTests(type: Jar, dependsOn: testClasses) { - classifier = 'tests' + archiveClassifier.set('tests') from sourceSets.test.output } diff --git a/transportable-udfs-test/transportable-udfs-test-api/build.gradle b/transportable-udfs-test/transportable-udfs-test-api/build.gradle index 3ecf7605..a39695c6 100644 --- a/transportable-udfs-test/transportable-udfs-test-api/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-api/build.gradle @@ -4,6 +4,8 @@ dependencies { api project(":transportable-udfs-test:transportable-udfs-test-spi") implementation project(":transportable-udfs-api") - api 'org.testng:testng:6.11' + api 'org.junit.jupiter:junit-jupiter-api:5.10.2' + api 'org.junit.jupiter:junit-jupiter-engine:5.10.2' + api 'org.junit.jupiter:junit-jupiter-params:5.10.2' implementation 'com.google.guava:guava:24.1-jre' } \ No newline at end of file diff --git a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java index 4d2a7892..83d053f8 100644 --- a/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java +++ b/transportable-udfs-test/transportable-udfs-test-api/src/main/java/com/linkedin/transport/test/AbstractStdUDFTest.java @@ -17,7 +17,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.Arrays; -import java.util.LinkedHashMap; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -81,7 +81,7 @@ protected static List array(Object... elements) { protected static Map map(Object... args) { Preconditions.checkArgument(args.length % 2 == 0, "Total number of keys + values is expected to be an even number. Received: " + args.length); - Map dataMap = new LinkedHashMap<>(); + Map dataMap = new HashMap<>(); for (int i = 0; i < args.length; i += 2) { dataMap.put(args[i], args[i + 1]); } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/build.gradle b/transportable-udfs-test/transportable-udfs-test-generic/build.gradle index af18aea6..9aa1e3f1 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-generic/build.gradle @@ -6,7 +6,9 @@ dependencies { implementation project(':transportable-udfs-utils') implementation project(":transportable-udfs-test:transportable-udfs-test-api") implementation project(":transportable-udfs-test:transportable-udfs-test-spi") - implementation 'org.testng:testng:6.11' + implementation 'org.junit.jupiter:junit-jupiter-api:5.10.2' + implementation 'org.junit.jupiter:junit-jupiter-engine:5.10.2' + implementation 'org.junit.jupiter:junit-jupiter-params:5.10.2' compileOnly 'org.apache.hadoop:hadoop-common:2.7.4' compileOnly 'org.apache.hadoop:hadoop-mapreduce-client-core:2.7.4' runtimeOnly 'org.apache.hadoop:hadoop-common:2.7.4' diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java index 0c4d17dd..a78a3508 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java @@ -22,7 +22,7 @@ import com.linkedin.transport.test.spi.types.TestTypeUtils; import com.linkedin.transport.test.spi.types.UnknownTestType; import java.util.ArrayList; -import java.util.LinkedHashMap; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.IntStream; @@ -102,7 +102,7 @@ private Pair resolveArray(List array, TestType element private Pair resolveMap(Map map, TestType keyType, TestType valueType) { List resolvedKeyTypes = new ArrayList<>(); List resolvedValueTypes = new ArrayList<>(); - Map resolvedMap = new LinkedHashMap<>(); + Map resolvedMap = new HashMap<>(); map.forEach((key, value) -> { Pair resolvedKey = resolveParameter(key, keyType); Pair resolvedValue = resolveParameter(value, valueType); diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java index f7a56525..f480e27e 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java @@ -19,8 +19,8 @@ import java.util.List; import java.util.Map; import org.apache.commons.lang3.tuple.Pair; -import org.testng.Assert; +import static org.junit.jupiter.api.Assertions.*; public class GenericTester implements StdTester { @@ -50,14 +50,14 @@ public void setup( @Override public void check(TestCase testCase) { Pair result = _executor.executeQuery(testCase.getFunctionCall()); - Assert.assertEquals(result.getLeft(), + assertEquals(result.getLeft(), _typeFactory.createType(TypeSignature.parse(testCase.getExpectedOutputType()), _boundVariables)); if (testCase.getExpectedOutput() instanceof ByteBuffer) { byte[] expected = ((ByteBuffer) testCase.getExpectedOutput()).array(); byte[] actual = ((ByteBuffer) result.getRight()).array(); - Assert.assertEquals(actual, expected); + assertEquals(actual, expected); } else { - Assert.assertEquals(result.getRight(), testCase.getExpectedOutput()); + assertEquals(result.getRight(), testCase.getExpectedOutput()); } } } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java index beeeb684..dea43ec8 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericMap.java @@ -14,7 +14,7 @@ import java.util.AbstractSet; import java.util.Collection; import java.util.Iterator; -import java.util.LinkedHashMap; +import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -33,7 +33,7 @@ public GenericMap(Map map, TestType type) { } public GenericMap(TestType type) { - this(new LinkedHashMap<>(), type); + this(new HashMap<>(), type); } @Override diff --git a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/HiveTester.java b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/HiveTester.java index e2b447e6..b68523aa 100644 --- a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/HiveTester.java +++ b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/HiveTester.java @@ -35,7 +35,8 @@ import org.apache.hive.service.cli.RowSet; import org.apache.hive.service.cli.SessionHandle; import org.apache.hive.service.server.HiveServer2; -import org.testng.Assert; + +import static org.junit.jupiter.api.Assertions.*; public class HiveTester implements SqlStdTester { @@ -129,10 +130,10 @@ public void assertFunctionCall(String functionCallString, Object expectedOutputD Object[] row = rowSet.iterator().next(); Object result = row[0]; - Assert.assertEquals(result, expectedOutputData, "UDF output does not match"); + assertEquals(result, expectedOutputData, "UDF output does not match"); // Get the output data type and convert them to TypeInfo to compare ColumnDescriptor outputColumnDescriptor = _client.getResultSetMetadata(handle).getColumnDescriptors().get(0); - Assert.assertEquals(TypeInfoUtils.getTypeInfoFromTypeString(outputColumnDescriptor.getTypeName().toLowerCase()), + assertEquals(TypeInfoUtils.getTypeInfoFromTypeString(outputColumnDescriptor.getTypeName().toLowerCase()), TypeInfoUtils.getTypeInfoFromObjectInspector((ObjectInspector) expectedOutputType), "UDF output type does not match"); } else { diff --git a/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala index 8c18b21f..9f792267 100644 --- a/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala +++ b/transportable-udfs-test/transportable-udfs-test-spark_2.11/src/main/scala/com/linkedin/transport/test/spark/SparkTester.scala @@ -15,7 +15,7 @@ import com.linkedin.transport.test.spi.{SqlFunctionCallGenerator, SqlStdTester, import org.apache.spark.SparkException import org.apache.spark.sql.types._ import org.apache.spark.sql.{SparkSession, StdUDFTestUtils} -import org.testng.Assert +import org.junit.jupiter.api.Assertions._ import scala.collection.JavaConversions._ @@ -41,8 +41,8 @@ class SparkTester extends SqlStdTester { } catch { case e: SparkException => throw e.getCause() } - Assert.assertEquals(result.get(0), expectedOutputData) - Assert.assertEquals(getModifiedResultType(result.schema.head.dataType), expectedOutputType) + assertEquals(result.get(0), expectedOutputData) + assertEquals(getModifiedResultType(result.schema.head.dataType), expectedOutputType) } override def setup(topLevelStdUDFClassesAndImplementations: util.Map[Class[_ <: TopLevelStdUDF], diff --git a/transportable-udfs-test/transportable-udfs-test-trino/build.gradle b/transportable-udfs-test/transportable-udfs-test-trino/build.gradle index 1bf5fa63..4d8ed7a2 100644 --- a/transportable-udfs-test/transportable-udfs-test-trino/build.gradle +++ b/transportable-udfs-test/transportable-udfs-test-trino/build.gradle @@ -1,7 +1,7 @@ apply plugin: 'java' java { - toolchain.languageVersion.set(JavaLanguageVersion.of(17)) + toolchain.languageVersion.set(JavaLanguageVersion.of(21)) } dependencies { @@ -17,6 +17,7 @@ dependencies { implementation(group:'io.trino', name: 'trino-main', version: project.ext.'trino-version', classifier: 'tests') { exclude 'group': 'com.google.collections', 'module': 'google-collections' } + implementation (group:'io.trino', name: 'trino-testing', version: project.ext.'trino-version') implementation group: 'io.airlift', name: 'testing', version: '221' // The io.airlift.slice dependency below has to match its counterpart in trino-root's pom.xml file // If not specified, an older version is picked up transitively from another dependency diff --git a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestFunctionDependencies.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestFunctionDependencies.java index 94b49c0a..18986711 100644 --- a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestFunctionDependencies.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTestFunctionDependencies.java @@ -5,24 +5,25 @@ */ package com.linkedin.transport.test.trino; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.function.FunctionDependencies; import io.trino.spi.function.FunctionNullability; import io.trino.spi.function.InvocationConvention; import io.trino.spi.function.OperatorType; -import io.trino.spi.function.QualifiedFunctionName; +import io.trino.metadata.ResolvedFunction; import io.trino.spi.function.ScalarFunctionImplementation; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; import io.trino.spi.type.TypeSignature; -import io.trino.testing.LocalQueryRunner; +import io.trino.testing.DistributedQueryRunner; import java.util.List; public class TrinoTestFunctionDependencies implements FunctionDependencies { private final TypeManager typeManager; - private final LocalQueryRunner queryRunner; + private final DistributedQueryRunner queryRunner; - public TrinoTestFunctionDependencies(TypeManager typeManager, LocalQueryRunner queryRunner) { + public TrinoTestFunctionDependencies(TypeManager typeManager, DistributedQueryRunner queryRunner) { this.typeManager = typeManager; this.queryRunner = queryRunner; } @@ -33,7 +34,7 @@ public Type getType(TypeSignature typeSignature) { } @Override - public FunctionNullability getFunctionNullability(QualifiedFunctionName name, List parameterTypes) { + public FunctionNullability getFunctionNullability(CatalogSchemaFunctionName name, List parameterTypes) { return null; } @@ -48,13 +49,13 @@ public FunctionNullability getCastNullability(Type fromType, Type toType) { } @Override - public ScalarFunctionImplementation getScalarFunctionImplementation(QualifiedFunctionName name, + public ScalarFunctionImplementation getScalarFunctionImplementation(CatalogSchemaFunctionName name, List parameterTypes, InvocationConvention invocationConvention) { return null; } @Override - public ScalarFunctionImplementation getScalarFunctionImplementationSignature(QualifiedFunctionName name, + public ScalarFunctionImplementation getScalarFunctionImplementationSignature(CatalogSchemaFunctionName name, List parameterTypes, InvocationConvention invocationConvention) { return null; } @@ -62,9 +63,12 @@ public ScalarFunctionImplementation getScalarFunctionImplementationSignature(Qua @Override public ScalarFunctionImplementation getOperatorImplementation(OperatorType operatorType, List parameterTypes, InvocationConvention invocationConvention) { - return queryRunner.getFunctionManager() - .getScalarFunctionImplementation(queryRunner.getMetadata().resolveOperator(queryRunner.getDefaultSession(), operatorType, parameterTypes), - invocationConvention); + var planner = queryRunner.getCoordinator().getPlannerContext(); + var metadata = planner.getMetadata(); + + ResolvedFunction resolved = metadata.resolveOperator(operatorType, parameterTypes); + return planner.getFunctionManager() + .getScalarFunctionImplementation(resolved, invocationConvention); } @Override diff --git a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java index 3c188360..e0494b56 100644 --- a/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java +++ b/transportable-udfs-test/transportable-udfs-test-trino/src/main/java/com/linkedin/transport/test/trino/TrinoTester.java @@ -15,15 +15,17 @@ import com.linkedin.transport.trino.TransportConnector; import com.linkedin.transport.trino.TransportConnectorMetadata; import com.linkedin.transport.trino.TransportFunctionProvider; -import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.client.ClientCapabilities; +import io.trino.spi.Plugin; +import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.Connector; import io.trino.spi.connector.ConnectorContext; import io.trino.spi.connector.ConnectorFactory; import io.trino.spi.connector.ConnectorMetadata; import io.trino.spi.function.BoundSignature; import io.trino.metadata.FunctionBinding; +import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.function.FunctionId; import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.udf.StdUDF; @@ -33,20 +35,26 @@ import com.linkedin.transport.test.spi.SqlStdTester; import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter; import io.trino.spi.function.FunctionProvider; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.SqlPath; import io.trino.sql.query.QueryAssertions; -import io.trino.testing.LocalQueryRunner; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedRow; import io.trino.testing.TestingSession; import io.trino.type.InternalTypeManager; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import static io.trino.testing.MaterializedResult.*; import static io.trino.type.UnknownType.UNKNOWN; import static org.assertj.core.api.Assertions.*; @@ -57,20 +65,17 @@ public class TrinoTester implements SqlStdTester { private SqlFunctionCallGenerator _sqlFunctionCallGenerator; private ToPlatformTestOutputConverter _toPlatformTestOutputConverter; private Session _session; - private FeaturesConfig _featuresConfig; - private LocalQueryRunner _runner; + private DistributedQueryRunner _runner; private QueryAssertions _queryAssertions; - public TrinoTester() { + public TrinoTester() throws Exception { _stdFactory = null; _sqlFunctionCallGenerator = new TrinoSqlFunctionCallGenerator(); _toPlatformTestOutputConverter = new ToTrinoTestOutputConverter(); - SqlPath sqlPath = new SqlPath("LINKEDIN.TRANSPORT"); + SqlPath sqlPath = new SqlPath(List.of(new CatalogSchemaName("linkedin", "transport")), "linkedin.transport"); _session = TestingSession.testSessionBuilder().setPath(sqlPath).setClientCapabilities((Set) Arrays.stream( ClientCapabilities.values()).map(Enum::toString).collect(ImmutableSet.toImmutableSet())).build(); - _featuresConfig = new FeaturesConfig(); - _runner = LocalQueryRunner.builder(_session).withFeaturesConfig(_featuresConfig).build(); - _queryAssertions = new QueryAssertions(_runner); + _runner = DistributedQueryRunner.builder(_session).build(); } @Override @@ -90,14 +95,22 @@ public void setup( ConnectorFactory connectorFactory = new ConnectorFactory() { @Override public String getName() { - return "TRANSPORT"; + return "transport"; } @Override public Connector create(String catalogName, Map config, ConnectorContext context) { return connector; } }; - _runner.createCatalog("LINKEDIN", connectorFactory, Collections.emptyMap()); + + _runner.installPlugin(new Plugin() { + @Override + public Iterable getConnectorFactories() { + return ImmutableList.of(connectorFactory); + } + }); + _runner.createCatalog("linkedin", "transport", Collections.emptyMap()); + _queryAssertions = new QueryAssertions(_runner); } @Override @@ -105,7 +118,7 @@ public StdFactory getStdFactory() { if (_stdFactory == null) { FunctionBinding functionBinding = new FunctionBinding( new FunctionId("test"), - new BoundSignature("test", UNKNOWN, ImmutableList.of()), + new BoundSignature(new CatalogSchemaFunctionName("linkedin", "transport", "test"), UNKNOWN, ImmutableList.of()), ImmutableMap.of(), ImmutableMap.of()); _stdFactory = new TrinoFactory(functionBinding, new TrinoTestFunctionDependencies(InternalTypeManager.TESTING_TYPE_MANAGER, _runner)); @@ -134,10 +147,68 @@ public void check(TestCase testCase) { } Object expectedOutputType = getPlatformType(testCase.getExpectedOutputType()); Object expectedOutput = testCase.getExpectedOutput(); - if (expectedOutput instanceof Row) { - expectedOutput = ((Row) expectedOutput).getFields(); - } + expectedOutput = normalizeExpected(expectedOutput, (Type) expectedOutputType); + QueryAssertions.ExpressionAssertProvider expressionAssertProvider = _queryAssertions.function(functionName, functionArguments); - assertThat(expressionAssertProvider).hasType((Type) expectedOutputType).isEqualTo(expectedOutput); + QueryAssertions.ExpressionAssert expressionAssert = assertThat(expressionAssertProvider).hasType((Type) expectedOutputType); + expressionAssert.isEqualTo(expectedOutput); + } + + private Object normalizeExpected(Object expected, Type expectedType) { + if (expected == null) { + return null; + } + + if (expectedType instanceof RowType) { + RowType rowType = (RowType) expectedType; + if (expected instanceof MaterializedRow) { + return expected; + } + + final List fields; + if (expected instanceof Row) { + Row r = (Row) expected; + fields = r.getFields(); + } else if (expected instanceof List) { + List l = (List) expected; + fields = l; + } else { + throw new IllegalArgumentException( + "Expected value for RowType must be Row, List, or MaterializedRow; got " + expected.getClass()); + } + + List trinoFields = rowType.getFields(); + List normalized = new ArrayList<>(trinoFields.size()); + for (int i = 0; i < trinoFields.size(); i++) { + Type fType = trinoFields.get(i).getType(); + Object fVal = (i < fields.size()) ? fields.get(i) : null; + normalized.add(normalizeExpected(fVal, fType)); // recurse for nested rows/arrays/maps + } + return new MaterializedRow(DEFAULT_PRECISION, normalized); + } + + if (expectedType instanceof ArrayType) { + ArrayType arrayType = (ArrayType) expectedType; + List list = (List) expected; + List out = new ArrayList<>(list.size()); + for (Object elem : list) { + out.add(normalizeExpected(elem, arrayType.getElementType())); // recurse + } + return out; + } + + if (expectedType instanceof MapType) { + MapType mapType = (MapType) expectedType; + Map map = (Map) expected; + Map out = new LinkedHashMap<>(); + for (Map.Entry e : map.entrySet()) { + Object key = normalizeExpected(e.getKey(), mapType.getKeyType()); + Object val = normalizeExpected(e.getValue(), mapType.getValueType()); + out.put(key, val); + } + return out; + } + + return expected; } -} +} \ No newline at end of file diff --git a/transportable-udfs-trino-plugin/build.gradle b/transportable-udfs-trino-plugin/build.gradle index d3d25d4f..c692bbbd 100644 --- a/transportable-udfs-trino-plugin/build.gradle +++ b/transportable-udfs-trino-plugin/build.gradle @@ -3,7 +3,7 @@ apply plugin: 'distribution' apply plugin: 'maven-publish' java { - toolchain.languageVersion.set(JavaLanguageVersion.of(17)) + toolchain.languageVersion.set(JavaLanguageVersion.of(21)) } dependencies { @@ -20,6 +20,8 @@ dependencies { } compileOnly(group:'io.trino', name: 'trino-spi', version: project.ext.'trino-version') testImplementation (group:'io.trino', name: 'trino-main', version: project.ext.'trino-version') + testImplementation (group:'io.trino', name: 'trino-testing', version: project.ext.'trino-version') + testImplementation (group:'io.trino', name: 'trino-tpch', version: project.ext.'trino-version') } // packaging as a shaded jar following the guideline from Trino plugin diff --git a/transportable-udfs-trino-plugin/src/main/java/com/linkedin/transport/trino/TransportConnector.java b/transportable-udfs-trino-plugin/src/main/java/com/linkedin/transport/trino/TransportConnector.java index d780c712..845ea22c 100644 --- a/transportable-udfs-trino-plugin/src/main/java/com/linkedin/transport/trino/TransportConnector.java +++ b/transportable-udfs-trino-plugin/src/main/java/com/linkedin/transport/trino/TransportConnector.java @@ -30,7 +30,7 @@ import java.util.Optional; import java.util.ServiceLoader; import java.util.stream.Collectors; -import javax.inject.Inject; +import jakarta.inject.Inject; import static java.util.Objects.*; diff --git a/transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TransportPluginTest.java b/transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TestTransportPlugin.java similarity index 73% rename from transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TransportPluginTest.java rename to transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TestTransportPlugin.java index 6b227d4e..53509516 100644 --- a/transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TransportPluginTest.java +++ b/transportable-udfs-trino-plugin/src/test/java/com/linkedin/transport/trino/TestTransportPlugin.java @@ -7,15 +7,16 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.FeaturesConfig; import io.trino.Session; import io.trino.client.ClientCapabilities; +import io.trino.spi.connector.CatalogSchemaName; import io.trino.sql.SqlPath; -import io.trino.testing.LocalQueryRunner; +import io.trino.testing.DistributedQueryRunner; import io.trino.testing.MaterializedResult; import io.trino.testing.TestingSession; import java.util.Arrays; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.testng.Assert; import org.testng.annotations.AfterClass; @@ -25,34 +26,38 @@ import static org.testng.Assert.*; -public class TransportPluginTest { - private final String udfRepoDir = getClass().getClassLoader().getResource("transport-udf-repo").getPath(); - private LocalQueryRunner queryRunner; +public class TestTransportPlugin { + private final String udfRepoDir = TestTransportPlugin.class.getClassLoader() + .getResource("transport-udf-repo") + .getPath(); + private DistributedQueryRunner queryRunner; @BeforeClass - public void setUp() { - SqlPath sqlPath = new SqlPath("LINKEDIN.transport"); - FeaturesConfig featuresConfig = new FeaturesConfig(); + public void setUp() throws Exception { + SqlPath sqlPath = new SqlPath(List.of(new CatalogSchemaName("linkedin", "transport")), "linkedin_transport"); Session session = TestingSession.testSessionBuilder().setPath(sqlPath).setClientCapabilities((Set) Arrays.stream( ClientCapabilities.values()).map(Enum::toString).collect(ImmutableSet.toImmutableSet())).build(); - queryRunner = LocalQueryRunner.builder(session).withFeaturesConfig(featuresConfig).build(); + queryRunner = DistributedQueryRunner.builder(session).build(); queryRunner.installPlugin(new TransportPlugin()); - queryRunner.createCatalog("LINKEDIN", "transport", ImmutableMap.of("transport.udf.repo", udfRepoDir)); + queryRunner.createCatalog("linkedin", "transport", ImmutableMap.of("transport.udf.repo", udfRepoDir)); } - @AfterClass - public void tearDown() { - queryRunner.close(); + @AfterClass(alwaysRun = true) + public void tearDown() throws Exception { + if (queryRunner != null) { + queryRunner.close(); + queryRunner = null; + } } @Test public void testTransportUdfIsAccessible() { - String query = "SELECT array_element_at(array[1,2,3], 2)"; + String query = "SELECT linkedin.transport.array_element_at(array[1,2,3], 2)"; MaterializedResult result = queryRunner.execute(query); Assert.assertEquals(result.getRowCount(), 1); Assert.assertEquals(((int) result.getMaterializedRows().get(0).getField(0)), 3); - String camelCaseQuery = "SELECT Array_Element_At(array[1,2,3], 2)"; + String camelCaseQuery = "SELECT linkedin.transport.Array_Element_At(array[1,2,3], 2)"; MaterializedResult camelCaseResult = queryRunner.execute(camelCaseQuery); Assert.assertEquals(camelCaseResult.getRowCount(), 1); Assert.assertEquals(((int) camelCaseResult.getMaterializedRows().get(0).getField(0)), 3); @@ -60,7 +65,7 @@ public void testTransportUdfIsAccessible() { @Test public void testTransportUdfInShowFunctions() { - String showFunctionQuery = "SHOW FUNCTIONS LIKE 'array_element_at'"; + String showFunctionQuery = "SHOW FUNCTIONS FROM linkedin.transport LIKE 'array_element_at'"; MaterializedResult showFunctionResult = queryRunner.execute(showFunctionQuery); Assert.assertEquals(showFunctionResult.getRowCount(), 1); Assert.assertEquals(((String) showFunctionResult.getMaterializedRows().get(0).getField(0)), "array_element_at"); @@ -84,4 +89,4 @@ public void testTransportUDFClassLoader() { // two UDF JARs are being loaded, so we expect two classloaders assertEquals(classLoaders.size(), 2); } -} +} \ No newline at end of file diff --git a/transportable-udfs-trino/build.gradle b/transportable-udfs-trino/build.gradle index 97c53fb6..c5e61e1e 100644 --- a/transportable-udfs-trino/build.gradle +++ b/transportable-udfs-trino/build.gradle @@ -1,7 +1,7 @@ apply plugin: 'java' java { - toolchain.languageVersion.set(JavaLanguageVersion.of(17)) + toolchain.languageVersion.set(JavaLanguageVersion.of(21)) } dependencies { diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java index 54a902e8..ba0652e9 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/StdUdfWrapper.java @@ -73,13 +73,14 @@ public abstract class StdUdfWrapper { private final FunctionMetadata functionMetadata; public StdUdfWrapper(StdUDF stdUDF) { - this.functionMetadata = FunctionMetadata.builder(FunctionKind.SCALAR) + String functionName = ((TopLevelStdUDF) stdUDF).getFunctionName(); + + this.functionMetadata = FunctionMetadata.builder(functionName, FunctionKind.SCALAR) .nullable() .nondeterministic() .description(((TopLevelStdUDF) stdUDF).getFunctionDescription()) .argumentNullability(getArgumentNullabilityObjects(stdUDF.getNullableArguments())) .signature(Signature.builder() - .name(((TopLevelStdUDF) stdUDF).getFunctionName()) .typeVariableConstraints(getTypeVariableConstraintsForStdUdf(stdUDF)) .returnType(parseTypeSignature(quoteReservedKeywords(stdUDF.getOutputParameterSignature()), ImmutableSet.of())) .argumentTypes(stdUDF.getInputParameterSignatures().stream() diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java index 651daea7..41804b96 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/TrinoWrapper.java @@ -32,6 +32,8 @@ import io.airlift.slice.Slice; import io.trino.spi.TrinoException; import io.trino.spi.block.Block; +import io.trino.spi.block.SqlMap; +import io.trino.spi.block.SqlRow; import io.trino.spi.type.ArrayType; import io.trino.spi.type.BigintType; import io.trino.spi.type.BooleanType; @@ -91,9 +93,9 @@ public static StdData createStdData(Object trinoData, Type trinoType, StdFactory } else if (trinoType instanceof ArrayType) { return new TrinoArray((Block) trinoData, (ArrayType) trinoType, stdFactory); } else if (trinoType instanceof MapType) { - return new TrinoMap((Block) trinoData, trinoType, stdFactory); + return new TrinoMap((SqlMap) trinoData, (MapType) trinoType, stdFactory); } else if (trinoType instanceof RowType) { - return new TrinoStruct((Block) trinoData, trinoType, stdFactory); + return new TrinoStruct((SqlRow) trinoData, trinoType, stdFactory); } assert false : "Unrecognized Trino Type: " + trinoType.getClass(); return null; diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java index 4d0dfa5d..9cc5dfd3 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoArray.java @@ -50,7 +50,7 @@ public int size() { @Override public StdData get(int idx) { - Block sourceBlock = _mutable == null ? _block : _mutable; + Block sourceBlock = _mutable == null ? _block : _mutable.build(); int position = TrinoWrapper.checkedIndexToBlockPosition(sourceBlock, idx); Object element = readNativeValue(_elementType, sourceBlock, position); return TrinoWrapper.createStdData(element, _elementType, _stdFactory); @@ -77,7 +77,7 @@ public void setUnderlyingData(Object value) { @Override public Iterator iterator() { return new Iterator() { - Block sourceBlock = _mutable == null ? _block : _mutable; + Block sourceBlock = _mutable == null ? _block : _mutable.build(); int size = TrinoArray.this.size(); int position = 0; diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java index 16893bcc..940c842a 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoFloat.java @@ -8,6 +8,7 @@ import com.linkedin.transport.api.data.StdFloat; import io.trino.spi.block.BlockBuilder; +import static io.trino.spi.type.IntegerType.INTEGER; import static java.lang.Float.*; @@ -36,6 +37,6 @@ public void setUnderlyingData(Object value) { @Override public void writeToBlock(BlockBuilder blockBuilder) { - blockBuilder.writeInt(floatToIntBits(_float)); + INTEGER.writeInt(blockBuilder, floatToIntBits(_float)); } } diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java index 73c74637..c3330a2d 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoMap.java @@ -1,7 +1,6 @@ -/** - * Copyright 2018 LinkedIn Corporation. All rights reserved. +/* + * Copyright 2018 LinkedIn Corporation. * Licensed under the BSD-2 Clause license. - * See LICENSE in the project root for license information. */ package com.linkedin.transport.trino.data; @@ -16,98 +15,105 @@ import io.trino.spi.TrinoException; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.MapHashTables.HashBuildMode; +import io.trino.spi.block.SqlMap; import io.trino.spi.function.OperatorType; import io.trino.spi.type.MapType; import io.trino.spi.type.Type; + import java.lang.invoke.MethodHandle; +import java.lang.reflect.Method; import java.util.AbstractCollection; import java.util.AbstractSet; import java.util.Collection; import java.util.Iterator; import java.util.Set; -import static io.trino.spi.StandardErrorCode.*; -import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL; import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN; -import static io.trino.spi.type.TypeUtils.*; - +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.type.TypeUtils.readNativeValue; public class TrinoMap extends TrinoData implements StdMap { + private final Type _keyType; + private final Type _valueType; + private final MapType _mapType; + private final MethodHandle _keyEqualsMethod; + private final StdFactory _stdFactory; - final Type _keyType; - final Type _valueType; - final Type _mapType; - final MethodHandle _keyEqualsMethod; - final StdFactory _stdFactory; - Block _block; - - public TrinoMap(Type mapType, StdFactory stdFactory) { - BlockBuilder mutable = mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); - mutable.beginBlockEntry(); - mutable.closeEntry(); - _block = ((MapType) mapType).getObject(mutable.build(), 0); - - _keyType = ((MapType) mapType).getKeyType(); - _valueType = ((MapType) mapType).getValueType(); - _mapType = mapType; + private SqlMap _map; + public TrinoMap(MapType mapType, StdFactory stdFactory) { + _mapType = mapType; + _keyType = mapType.getKeyType(); + _valueType = mapType.getValueType(); _stdFactory = stdFactory; _keyEqualsMethod = ((TrinoFactory) stdFactory).getOperatorHandle( - OperatorType.EQUAL, ImmutableList.of(_keyType, _keyType), simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); + OperatorType.EQUAL, + ImmutableList.of(_keyType, _keyType), + simpleConvention(NULLABLE_RETURN, NEVER_NULL, NEVER_NULL)); + + // Start with an empty SqlMap + Block emptyKeys = _keyType.createBlockBuilder(null, 0).build(); + Block emptyValues = _valueType.createBlockBuilder(null, 0).build(); + _map = new SqlMap(mapType, HashBuildMode.STRICT_EQUALS, emptyKeys, emptyValues); } - public TrinoMap(Block block, Type mapType, StdFactory stdFactory) { + public TrinoMap(SqlMap map, MapType mapType, StdFactory stdFactory) { this(mapType, stdFactory); - _block = block; + _map = map; } @Override public int size() { - return _block.getPositionCount() / 2; + return keyBlock().getPositionCount(); } @Override public StdData get(StdData key) { Object trinoKey = ((PlatformData) key).getUnderlyingData(); - int i = seekKey(trinoKey); - if (i != -1) { - Object value = readNativeValue(_valueType, _block, i); - StdData stdValue = TrinoWrapper.createStdData(value, _valueType, _stdFactory); - return stdValue; - } else { + int idx = seekKeyIndex(trinoKey); + if (idx == -1) { return null; } + Object value = readNativeValue(_valueType, valueBlock(), idx); + return TrinoWrapper.createStdData(value, _valueType, _stdFactory); } - // TODO: Do not copy the _mutable BlockBuilder on every update. As long as updates are append-only or for fixed-size - // types, we can skip copying. @Override public void put(StdData key, StdData value) { - BlockBuilder mutable = _mapType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); - BlockBuilder entryBuilder = mutable.beginBlockEntry(); Object trinoKey = ((PlatformData) key).getUnderlyingData(); - int valuePosition = seekKey(trinoKey); - for (int i = 0; i < _block.getPositionCount(); i += 2) { - // Write the current key to the map - _keyType.appendTo(_block, i, entryBuilder); - // Find out if we need to change the corresponding value - if (i == valuePosition - 1) { - // Use the user-supplied value - ((TrinoData) value).writeToBlock(entryBuilder); + int existingIndex = seekKeyIndex(trinoKey); + + int n = size(); + int newSize = (existingIndex == -1) ? n + 1 : n; + + BlockBuilder keyBuilder = _keyType.createBlockBuilder(null, newSize); + BlockBuilder valueBuilder = _valueType.createBlockBuilder(null, newSize); + + // copy existing entries, replacing value if key matches + for (int i = 0; i < n; i++) { + _keyType.appendTo(keyBlock(), i, keyBuilder); + if (i == existingIndex) { + ((TrinoData) value).writeToBlock(valueBuilder); } else { - // Use the existing value in original _block - _valueType.appendTo(_block, i + 1, entryBuilder); + _valueType.appendTo(valueBlock(), i, valueBuilder); } } - if (valuePosition == -1) { - ((TrinoData) key).writeToBlock(entryBuilder); - ((TrinoData) value).writeToBlock(entryBuilder); + + // append new entry if key not present + if (existingIndex == -1) { + ((TrinoData) key).writeToBlock(keyBuilder); + ((TrinoData) value).writeToBlock(valueBuilder); } - mutable.closeEntry(); - _block = ((MapType) _mapType).getObject(mutable.build(), 0); + _map = new SqlMap(_mapType, HashBuildMode.STRICT_EQUALS, keyBuilder.build(), valueBuilder.build()); + } + + @Override + public boolean containsKey(StdData key) { + return get(key) != null; } public Set keySet() { @@ -115,21 +121,17 @@ public Set keySet() { @Override public Iterator iterator() { return new Iterator() { - int i = -2; - - @Override - public boolean hasNext() { - return !(i + 2 == size() * 2); + int i = -1; + @Override public boolean hasNext() { + return i + 1 < size(); } - - @Override - public StdData next() { - i += 2; - return TrinoWrapper.createStdData(readNativeValue(_keyType, _block, i), _keyType, _stdFactory); + @Override public StdData next() { + i++; + Object k = readNativeValue(_keyType, keyBlock(), i); + return TrinoWrapper.createStdData(k, _keyType, _stdFactory); } }; } - @Override public int size() { return TrinoMap.this.size(); @@ -140,25 +142,20 @@ public int size() { @Override public Collection values() { return new AbstractCollection() { - @Override public Iterator iterator() { return new Iterator() { - int i = -2; - - @Override - public boolean hasNext() { - return !(i + 2 == size() * 2); + int i = -1; + @Override public boolean hasNext() { + return i + 1 < size(); } - - @Override - public StdData next() { - i += 2; - return TrinoWrapper.createStdData(readNativeValue(_valueType, _block, i + 1), _valueType, _stdFactory); + @Override public StdData next() { + i++; + Object v = readNativeValue(_valueType, valueBlock(), i); + return TrinoWrapper.createStdData(v, _valueType, _stdFactory); } }; } - @Override public int size() { return TrinoMap.this.size(); @@ -167,25 +164,26 @@ public int size() { } @Override - public boolean containsKey(StdData key) { - return get(key) != null; + public Object getUnderlyingData() { + return _map; } @Override - public Object getUnderlyingData() { - return _block; + public void setUnderlyingData(Object value) { + _map = (SqlMap) value; } @Override - public void setUnderlyingData(Object value) { - _block = (Block) value; + public void writeToBlock(BlockBuilder blockBuilder) { + _mapType.writeObject(blockBuilder, _map); } - private int seekKey(Object key) { - for (int i = 0; i < _block.getPositionCount(); i += 2) { + private int seekKeyIndex(Object key) { + Block keys = keyBlock(); + for (int i = 0; i < keys.getPositionCount(); i++) { try { - if ((boolean) _keyEqualsMethod.invoke(readNativeValue(_keyType, _block, i), key)) { - return i + 1; + if ((boolean) _keyEqualsMethod.invoke(readNativeValue(_keyType, keys, i), key)) { + return i; } } catch (Throwable t) { Throwables.propagateIfInstanceOf(t, Error.class); @@ -196,8 +194,29 @@ private int seekKey(Object key) { return -1; } - @Override - public void writeToBlock(BlockBuilder blockBuilder) { - _mapType.writeObject(blockBuilder, _block); + private Block keyBlock() { + return getBlock(_map, /*key=*/true); + } + + private Block valueBlock() { + return getBlock(_map, /*key=*/false); + } + + /** + * SqlMap getters differ slightly across Trino versions (getKeyBlock()/getValueBlock() vs keyBlock()/valueBlock()). + * Use a tiny reflective shim so this class compiles against either. + */ + private static Block getBlock(SqlMap map, boolean key) { + try { + Method m; + try { + m = SqlMap.class.getMethod(key ? "getRawKeyBlock" : "getRawValueBlock"); + } catch (NoSuchMethodException e) { + m = SqlMap.class.getMethod(key ? "keyBlock" : "valueBlock"); + } + return (Block) m.invoke(map); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Unable to access SqlMap blocks", e); + } } } diff --git a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java index c94ae335..cb854a3c 100644 --- a/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java +++ b/transportable-udfs-trino/src/main/java/com/linkedin/transport/trino/data/TrinoStruct.java @@ -11,147 +11,184 @@ import com.linkedin.transport.trino.TrinoWrapper; import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; -import io.trino.spi.block.BlockBuilderStatus; -import io.trino.spi.block.PageBuilderStatus; +import io.trino.spi.block.SqlRow; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; + import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; -import static io.trino.spi.type.TypeUtils.*; - +import static io.trino.spi.type.TypeUtils.readNativeValue; public class TrinoStruct extends TrinoData implements StdStruct { - final RowType _rowType; - final StdFactory _stdFactory; - Block _block; + private final RowType rowType; + private final StdFactory stdFactory; + + // Trino v446+: ROW values are represented as SqlRow + private SqlRow rowData; public TrinoStruct(Type rowType, StdFactory stdFactory) { - _rowType = (RowType) rowType; - _stdFactory = stdFactory; + this.rowType = (RowType) rowType; + this.stdFactory = stdFactory; } - public TrinoStruct(Block block, Type rowType, StdFactory stdFactory) { + /** Prefer using this ctor if you already have a SqlRow from Trino. */ + public TrinoStruct(SqlRow sqlRow, Type rowType, StdFactory stdFactory) { this(rowType, stdFactory); - _block = block; + this.rowData = sqlRow; } public TrinoStruct(List fieldTypes, StdFactory stdFactory) { - _stdFactory = stdFactory; - _rowType = RowType.anonymous(fieldTypes); + this.stdFactory = stdFactory; + this.rowType = RowType.anonymous(fieldTypes); } public TrinoStruct(List fieldNames, List fieldTypes, StdFactory stdFactory) { - _stdFactory = stdFactory; + this.stdFactory = stdFactory; List fields = IntStream.range(0, fieldNames.size()) .mapToObj(i -> new RowType.Field(Optional.ofNullable(fieldNames.get(i)), fieldTypes.get(i))) .collect(Collectors.toList()); - _rowType = RowType.from(fields); + this.rowType = RowType.from(fields); } @Override public StdData getField(int index) { - int position = TrinoWrapper.checkedIndexToBlockPosition(_block, index); - if (position == -1) { + if (rowData == null) { return null; } - Type elementType = _rowType.getFields().get(position).getType(); - Object element = readNativeValue(elementType, _block, position); - return TrinoWrapper.createStdData(element, elementType, _stdFactory); + int offset = rowData.getRawIndex(); + Type fieldType = rowType.getFields().get(index).getType(); + Block fieldBlock = rowData.getRawFieldBlock(index); + Object element = readNativeValue(fieldType, fieldBlock, offset); + return TrinoWrapper.createStdData(element, fieldType, stdFactory); } @Override public StdData getField(String name) { - int index = -1; - Type elementType = null; - int i = 0; - for (RowType.Field field : _rowType.getFields()) { - if (field.getName().isPresent() && name.equals(field.getName().get())) { - index = i; - elementType = field.getType(); + if (rowData == null) { + return null; + } + int idx = -1; + Type t = null; + for (int i = 0; i < rowType.getFields().size(); i++) { + var f = rowType.getFields().get(i); + if (f.getName().isPresent() && name.equals(f.getName().get())) { + idx = i; + t = f.getType(); break; } - i++; } - if (index == -1) { + if (idx == -1) { return null; } - Object element = readNativeValue(elementType, _block, index); - return TrinoWrapper.createStdData(element, elementType, _stdFactory); + int offset = rowData.getRawIndex(); + Block fieldBlock = rowData.getRawFieldBlock(idx); + Object element = readNativeValue(t, fieldBlock, offset); + return TrinoWrapper.createStdData(element, t, stdFactory); } @Override public void setField(int index, StdData value) { - // TODO: This is not the right way to get this object. The status should be passed in from the invocation of the - // function and propagated to here. See PRESTO-1359 for more details. - BlockBuilderStatus blockBuilderStatus = new PageBuilderStatus().createBlockBuilderStatus(); - BlockBuilder mutable = _rowType.createBlockBuilder(blockBuilderStatus, 1); - BlockBuilder rowBlockBuilder = mutable.beginBlockEntry(); - int i = 0; - for (RowType.Field field : _rowType.getFields()) { + int fieldCount = rowType.getFields().size(); + List fieldTypes = rowType.getTypeParameters(); + Block[] fieldBlocks = new Block[fieldCount]; + + int existingOffset = (rowData == null) ? 0 : rowData.getRawIndex(); + + for (int i = 0; i < fieldCount; i++) { + Type ft = fieldTypes.get(i); + BlockBuilder bb = ft.createBlockBuilder(null, 1); + if (i == index) { - ((TrinoData) value).writeToBlock(rowBlockBuilder); + ((TrinoData) value).writeToBlock(bb); } else { - if (_block == null) { - rowBlockBuilder.appendNull(); + if (rowData == null) { + bb.appendNull(); } else { - field.getType().appendTo(_block, i, rowBlockBuilder); + // copy existing value at this row's offset + ft.appendTo(rowData.getRawFieldBlock(i), existingOffset, bb); } } - i++; + fieldBlocks[i] = bb.build(); } - mutable.closeEntry(); - _block = _rowType.getObject(mutable.build(), 0); + + // Build a single-row SqlRow at offset 0 + this.rowData = new SqlRow(0, fieldBlocks); } @Override public void setField(String name, StdData value) { - BlockBuilder mutable = _rowType.createBlockBuilder(new PageBuilderStatus().createBlockBuilderStatus(), 1); - BlockBuilder rowBlockBuilder = mutable.beginBlockEntry(); - int i = 0; - for (RowType.Field field : _rowType.getFields()) { - if (field.getName().isPresent() && name.equals(field.getName().get())) { - ((TrinoData) value).writeToBlock(rowBlockBuilder); + int fieldCount = rowType.getFields().size(); + List fieldTypes = rowType.getTypeParameters(); + Block[] fieldBlocks = new Block[fieldCount]; + + int targetIndex = -1; + for (int i = 0; i < fieldCount; i++) { + var f = rowType.getFields().get(i); + if (f.getName().isPresent() && name.equals(f.getName().get())) { + targetIndex = i; + break; + } + } + if (targetIndex == -1) { + // Unknown field name; treat as no-op + return; + } + + int existingOffset = (rowData == null) ? 0 : rowData.getRawIndex(); + + for (int i = 0; i < fieldCount; i++) { + Type ft = fieldTypes.get(i); + BlockBuilder bb = ft.createBlockBuilder(null, 1); + + if (i == targetIndex) { + ((TrinoData) value).writeToBlock(bb); } else { - if (_block == null) { - rowBlockBuilder.appendNull(); + if (rowData == null) { + bb.appendNull(); } else { - field.getType().appendTo(_block, i, rowBlockBuilder); + ft.appendTo(rowData.getRawFieldBlock(i), existingOffset, bb); } } - i++; + fieldBlocks[i] = bb.build(); } - mutable.closeEntry(); - _block = _rowType.getObject(mutable.build(), 0); + + this.rowData = new SqlRow(0, fieldBlocks); } @Override public List fields() { - ArrayList fields = new ArrayList<>(); - for (int i = 0; i < _block.getPositionCount(); i++) { - Type elementType = _rowType.getFields().get(i).getType(); - Object element = readNativeValue(elementType, _block, i); - fields.add(TrinoWrapper.createStdData(element, elementType, _stdFactory)); + ArrayList out = new ArrayList<>(); + if (rowData == null) { + return out; + } + int offset = rowData.getRawIndex(); + int count = rowType.getFields().size(); + for (int i = 0; i < count; i++) { + Type t = rowType.getFields().get(i).getType(); + Block fieldBlock = rowData.getRawFieldBlock(i); + Object element = readNativeValue(t, fieldBlock, offset); + out.add(TrinoWrapper.createStdData(element, t, stdFactory)); } - return fields; + return out; } @Override public Object getUnderlyingData() { - return _block; + return rowData; } @Override public void setUnderlyingData(Object value) { - _block = (Block) value; + this.rowData = (SqlRow) value; } @Override public void writeToBlock(BlockBuilder blockBuilder) { - _rowType.writeObject(blockBuilder, getUnderlyingData()); + rowType.writeObject(blockBuilder, rowData); } } diff --git a/transportable-udfs-type-system/build.gradle b/transportable-udfs-type-system/build.gradle index 26ae6211..c8bbf5fb 100644 --- a/transportable-udfs-type-system/build.gradle +++ b/transportable-udfs-type-system/build.gradle @@ -8,7 +8,7 @@ dependencies { } task jarTests(type: Jar, dependsOn: testClasses) { - classifier = 'tests' + archiveClassifier.set('tests') from sourceSets.test.output }