diff --git a/pom.xml b/pom.xml index 772df8bc6da..65a529a5c24 100644 --- a/pom.xml +++ b/pom.xml @@ -62,6 +62,7 @@ zeppelin-jupyter-interpreter-shaded groovy spark + spark-connect spark-submit submarine markdown diff --git a/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml b/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml index 09ed9a39013..34e658fc2c8 100644 --- a/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml +++ b/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml @@ -4,29 +4,55 @@ channels: - defaults dependencies: - python >=3.9,<3.10 - - pyspark=3.3.2 + - pyspark=3.5 - pycodestyle - - scipy + # --- Core data libraries --- + - pandas - numpy + - scipy + - pyarrow + # --- Spark Connect protocol --- - grpcio - protobuf + # --- HTTP / networking --- + - requests + - urllib3 + # --- File format support --- + - openpyxl + - xlrd + - pyyaml + - tabulate + # --- GCP access --- + - google-cloud-storage + - google-auth + - gcsfs + # --- Visualization --- + - matplotlib + - seaborn + - plotly + - plotnine + - altair + - vega_datasets + - hvplot + # --- SQL on pandas --- - pandasql + # --- ML --- + - scikit-learn + - xgboost + # --- IPython / kernel --- - ipython - ipykernel - jupyter_client - - hvplot - - plotnine - - seaborn + # --- Data connectors --- - intake - intake-parquet - intake-xarray - - altair - - vega_datasets - - plotly + # --- pip-only packages --- - pip - pip: - # works for regular pip packages - bkzep==0.6.1 + - delta-spark==3.2.1 + # --- R support --- - r-base=3 - r-data.table - r-evaluate diff --git a/spark-connect/pom.xml b/spark-connect/pom.xml new file mode 100644 index 00000000000..cb2a5bd9dc3 --- /dev/null +++ b/spark-connect/pom.xml @@ -0,0 +1,130 @@ + + + + + 4.0.0 + + + zeppelin-interpreter-parent + org.apache.zeppelin + 0.11.2 + ../zeppelin-interpreter-parent/pom.xml + + + spark-connect-interpreter + jar + Zeppelin: Spark Connect Interpreter + Zeppelin Spark Connect support via gRPC client + + + 
spark-connect + 3.5.3 + 2.12 + + + + + org.apache.spark + spark-connect-client-jvm_${spark.scala.binary.version} + ${spark.connect.version} + + + + org.apache.zeppelin + zeppelin-python + ${project.version} + + + + org.apache.commons + commons-lang3 + + + + org.mockito + mockito-core + test + + + + + + + maven-resources-plugin + + + maven-shade-plugin + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + reference.conf + + + + + org.apache.zeppelin:zeppelin-interpreter-shaded + + + + + io.netty + org.apache.zeppelin.spark.connect.io.netty + + + com.google + org.apache.zeppelin.spark.connect.com.google + + + io.grpc + org.apache.zeppelin.spark.connect.io.grpc + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + + + + + spark-connect-3.5 + + true + + + 3.5.3 + + + + + diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java new file mode 100644 index 00000000000..d6adef101e6 --- /dev/null +++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.zeppelin.spark;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.python.IPythonInterpreter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Properties;

/**
 * PySpark Connect interpreter backed by IPython.
 * Shares the Java SparkSession owned by {@link SparkConnectInterpreter} with the
 * Python process over Py4j so Python paragraphs and SQL paragraphs talk to the
 * same remote Spark Connect session.
 */
public class IPySparkConnectInterpreter extends IPythonInterpreter {

  private static final Logger LOGGER = LoggerFactory.getLogger(IPySparkConnectInterpreter.class);

  // Sibling interpreters in the same interpreter session; resolved in open().
  private SparkConnectInterpreter sparkConnectInterpreter;
  private PySparkConnectInterpreter pySparkConnectInterpreter;
  // Guards against double initialization; open() is synchronized.
  private boolean opened = false;
  // Last paragraph context, re-published into the Python process before each run.
  private InterpreterContext currentContext;

  public IPySparkConnectInterpreter(Properties property) {
    super(property);
  }

  @Override
  public synchronized void open() throws InterpreterException {
    if (opened) {
      return;
    }

    sparkConnectInterpreter =
        getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class);
    pySparkConnectInterpreter =
        getInterpreterInTheSameSessionByClassName(PySparkConnectInterpreter.class, false);

    // The remote Spark session must exist before the Python side starts up.
    sparkConnectInterpreter.open();

    setProperty("zeppelin.python", pySparkConnectInterpreter.getPythonExec());
    setUseBuiltinPy4j(true);
    setAdditionalPythonInitFile("python/zeppelin_isparkconnect.py");
    super.open();
    opened = true;
  }

  @Override
  public InterpreterResult interpret(String st, InterpreterContext context)
      throws InterpreterException {
    InterpreterContext.set(context);
    currentContext = context;

    // Refresh the paragraph context inside the Python process before running user code.
    InterpreterResult contextResult =
        super.interpret("intp.setInterpreterContextInPython()", context);
    if (contextResult.code().equals(InterpreterResult.Code.ERROR)) {
      return new InterpreterResult(InterpreterResult.Code.ERROR, "Fail to setCurIntpContext");
    }

    return super.interpret(st, context);
  }

  /** Called back from Python (via Py4j) to bind the current paragraph context. */
  public void setInterpreterContextInPython() {
    InterpreterContext.set(currentContext);
  }

  /** @return the shared Spark Connect session, or null before open() has run. */
  public SparkSession getSparkSession() {
    return sparkConnectInterpreter == null ? null : sparkConnectInterpreter.getSparkSession();
  }

  @Override
  public void cancel(InterpreterContext context) throws InterpreterException {
    super.cancel(context);
    if (sparkConnectInterpreter != null) {
      sparkConnectInterpreter.cancel(context);
    }
  }

  @Override
  public void close() throws InterpreterException {
    LOGGER.info("Close IPySparkConnectInterpreter (opened={})", opened);
    try {
      super.close();
    } finally {
      // Reset all state even if super.close() fails so a later re-open starts clean.
      opened = false;
      sparkConnectInterpreter = null;
      pySparkConnectInterpreter = null;
      LOGGER.info("IPySparkConnectInterpreter closed and state reset — ready for re-open");
    }
  }

  @Override
  public int getProgress(InterpreterContext context) throws InterpreterException {
    // Spark Connect exposes no job progress to the client; report none.
    return 0;
  }

  /** Max rows to render; falls back to 1000 when the Spark interpreter is absent. */
  public int getMaxResult() {
    return sparkConnectInterpreter == null ? 1000 : sparkConnectInterpreter.getMaxResult();
  }

  /** Renders a DataFrame handed over from Python (via Py4j) as a Zeppelin table. */
  @SuppressWarnings("unchecked")
  public String formatDataFrame(Object df, int maxResult) {
    return SparkConnectUtils.showDataFrame((Dataset<Row>) df, maxResult);
  }
}
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Shared utility class for managing notebook-level locks. + * Ensures that only one query executes at a time per notebook, + * regardless of which interpreter (SparkConnectInterpreter or SparkConnectSqlInterpreter) is used. + */ +public class NotebookLockManager { + + // Locks per notebook to ensure one query at a time per notebook + private static final ConcurrentHashMap notebookLocks = + new ConcurrentHashMap<>(); + + /** + * Get or create a lock for the specified notebook. + * Uses fair locking to ensure FIFO ordering of query execution. + * + * @param noteId The notebook ID + * @return The lock for this notebook + */ + public static ReentrantLock getNotebookLock(String noteId) { + return notebookLocks.computeIfAbsent(noteId, + k -> new ReentrantLock(true)); // Fair lock for FIFO ordering + } + + /** + * Remove the lock for a notebook (cleanup when notebook is closed). 
+ * + * @param noteId The notebook ID + */ + public static void removeNotebookLock(String noteId) { + notebookLocks.remove(noteId); + } + + /** + * Get the number of active notebook locks (for monitoring/debugging). + * + * @return The number of active locks + */ + public static int getActiveLockCount() { + return notebookLocks.size(); + } +} diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java new file mode 100644 index 00000000000..1b79a4ac164 --- /dev/null +++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.zeppelin.spark;

import org.apache.commons.lang3.StringUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.python.IPythonInterpreter;
import org.apache.zeppelin.python.PythonInterpreter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Map;
import java.util.Properties;

/**
 * PySpark interpreter for Spark Connect.
 * Shares the Java SparkSession created by {@link SparkConnectInterpreter} (exposed to
 * Python over Py4j) so the Python side talks to the same 3.5.x-compatible client
 * as the SQL interpreter.
 */
public class PySparkConnectInterpreter extends PythonInterpreter {

  private static final Logger LOGGER = LoggerFactory.getLogger(PySparkConnectInterpreter.class);

  private SparkConnectInterpreter sparkConnectInterpreter;
  // Last paragraph context, re-published into the Python process before each call.
  private InterpreterContext currentContext;

  public PySparkConnectInterpreter(Properties property) {
    super(property);
    this.useBuiltinPy4j = true;
  }

  @Override
  public void open() throws InterpreterException {
    setProperty("zeppelin.python.useIPython",
        getProperty("zeppelin.pyspark.connect.useIPython", "true"));

    sparkConnectInterpreter =
        getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class);

    // The remote Spark session must be up before the Python process is launched.
    sparkConnectInterpreter.open();

    String pythonExec = getPythonExec();
    LOGGER.info("Python executable resolved: {}", pythonExec);

    // Delegate the Python process launch to PythonInterpreter — mirrors Spark's
    // PySparkInterpreter behavior (no pre-validation of the executable).
    super.open();

    if (!useIPython()) {
      try {
        bootstrapInterpreter("python/zeppelin_sparkconnect.py");
      } catch (IOException e) {
        LOGGER.error("Fail to bootstrap spark connect", e);
        throw new InterpreterException("Fail to bootstrap spark connect", e);
      }
    }
  }

  @Override
  public void close() throws InterpreterException {
    LOGGER.info("Close PySparkConnectInterpreter");
    super.close();
  }

  @Override
  protected IPythonInterpreter getIPythonInterpreter() throws InterpreterException {
    return getInterpreterInTheSameSessionByClassName(IPySparkConnectInterpreter.class, false);
  }

  @Override
  public InterpreterResult interpret(String st, InterpreterContext context)
      throws InterpreterException {
    currentContext = context;
    return super.interpret(st, context);
  }

  @Override
  protected void preCallPython(InterpreterContext context) {
    // Re-bind the paragraph context inside the Python process before each call.
    callPython(new PythonInterpretRequest(
        "intp.setInterpreterContextInPython()", false, false));
  }

  /** Called back from Python (via Py4j) to bind the current paragraph context. */
  public void setInterpreterContextInPython() {
    InterpreterContext.set(currentContext);
  }

  @Override
  protected Map<String, String> setupPythonEnv() throws IOException {
    Map<String, String> env = super.setupPythonEnv();

    // Export PYSPARK_PYTHON so Python subprocesses resolve the same interpreter
    // (following Spark's pattern).
    String pythonExec = getPythonExec();
    env.put("PYSPARK_PYTHON", pythonExec);
    LOGGER.info("Set PYSPARK_PYTHON: {}", pythonExec);

    // Conda Python binaries depend on shared libraries under conda/lib —
    // make sure LD_LIBRARY_PATH covers them.
    setupCondaLibraryPath(env, pythonExec);

    LOGGER.info("LD_LIBRARY_PATH: {}", env.get("LD_LIBRARY_PATH"));
    return env;
  }

  /**
   * Resolve the Python executable using Spark's PySpark lookup order:
   * spark.pyspark.driver.python, spark.pyspark.python, PYSPARK_DRIVER_PYTHON,
   * PYSPARK_PYTHON, zeppelin.python, then plain "python" from the system PATH.
   *
   * <p>No validation is performed — ProcessBuilder fails naturally on a bad path,
   * matching Spark's trust-the-configuration behavior.
   */
  @Override
  protected String getPythonExec() {
    // Ordered candidate table: {source label for logging, resolved value}.
    String[][] candidates = {
        {"spark.pyspark.driver.python", getProperty("spark.pyspark.driver.python", "")},
        {"spark.pyspark.python", getProperty("spark.pyspark.python", "")},
        {"PYSPARK_DRIVER_PYTHON", System.getenv("PYSPARK_DRIVER_PYTHON")},
        {"PYSPARK_PYTHON", System.getenv("PYSPARK_PYTHON")},
        {"zeppelin.python property", getProperty("zeppelin.python", "")},
    };
    for (String[] candidate : candidates) {
      if (StringUtils.isNotBlank(candidate[1])) {
        LOGGER.info("Using Python executable from {}: {}", candidate[0], candidate[1]);
        return candidate[1];
      }
    }
    LOGGER.info("No Python executable configured, defaulting to 'python' (will use system PATH)");
    return "python";
  }

  /**
   * If the configured Python lives under a conda tree (path contains "/conda/"),
   * prepend the matching conda lib directory to LD_LIBRARY_PATH so its shared
   * libraries resolve. PATH-resolved "python" is left untouched.
   */
  private void setupCondaLibraryPath(Map<String, String> env, String pythonExec) {
    if (pythonExec == null || !pythonExec.contains("/conda/")) {
      return;
    }
    // e.g. /opt/conda/default/bin/python3 -> conda base /opt/conda/default
    int binIndex = pythonExec.indexOf("/bin/");
    if (binIndex <= 0) {
      return;
    }
    String condaLib = pythonExec.substring(0, binIndex) + "/lib";
    File libDir = new File(condaLib);
    if (!libDir.exists() || !libDir.isDirectory()) {
      return;
    }
    String existing = env.getOrDefault("LD_LIBRARY_PATH", "");
    if (existing.isEmpty()) {
      env.put("LD_LIBRARY_PATH", condaLib);
    } else if (!existing.contains(condaLib)) {
      env.put("LD_LIBRARY_PATH", condaLib + ":" + existing);
    }
    LOGGER.info("Added conda lib directory to LD_LIBRARY_PATH: {}", condaLib);
  }

  /**
   * Exposes the Java SparkSession to the Python process via the Py4j gateway.
   * Same session as used by SparkConnectSqlInterpreter; null before open().
   */
  public SparkSession getSparkSession() {
    return sparkConnectInterpreter == null ? null : sparkConnectInterpreter.getSparkSession();
  }

  public SparkConnectInterpreter getSparkConnectInterpreter() {
    return sparkConnectInterpreter;
  }

  /** Max rows to render; falls back to zeppelin.spark.maxResult (default 1000). */
  public int getMaxResult() {
    if (sparkConnectInterpreter == null) {
      return Integer.parseInt(getProperty("zeppelin.spark.maxResult", "1000"));
    }
    return sparkConnectInterpreter.getMaxResult();
  }

  /** Renders a DataFrame handed over from Python (via Py4j) as a Zeppelin table. */
  @SuppressWarnings("unchecked")
  public String formatDataFrame(Object df, int maxResult) {
    return SparkConnectUtils.showDataFrame((Dataset<Row>) df, maxResult);
  }
}
package org.apache.zeppelin.spark;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.zeppelin.interpreter.AbstractInterpreter;
import org.apache.zeppelin.interpreter.InterpreterContext;
import org.apache.zeppelin.interpreter.InterpreterException;
import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
import org.apache.zeppelin.interpreter.ZeppelinContext;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.interpreter.util.SqlSplitter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

/**
 * Spark Connect interpreter for Zeppelin.
 * Connects to a remote Spark cluster via the Spark Connect gRPC protocol.
 *
 * <p>Session and concurrency model:
 * <ul>
 *   <li>Max Spark Connect sessions per user is capped
 *       ({@code zeppelin.spark.connect.maxSessionsPerUser}, default 5).
 *       Each interpreter instance creates one session; Zeppelin's binding mode
 *       (per-user / per-note / scoped / isolated) controls how many instances exist.</li>
 *   <li>Notebooks are unlimited -- users may open as many as they like.</li>
 *   <li>Within a single notebook only one query executes at a time
 *       (per-notebook fair lock via {@link NotebookLockManager}).</li>
 * </ul>
 */
public class SparkConnectInterpreter extends AbstractInterpreter {

  private static final Logger LOGGER = LoggerFactory.getLogger(SparkConnectInterpreter.class);

  /** user -> number of active SparkSession instances owned by that user. */
  private static final ConcurrentHashMap<String, Integer> userSessionCount =
      new ConcurrentHashMap<>();

  private static final int DEFAULT_MAX_SESSIONS_PER_USER = 5;

  private SparkSession sparkSession;
  private SqlSplitter sqlSplitter;
  private int maxResult;
  private String currentUser;
  // Tracks whether this instance holds a per-user session slot, so close()
  // releases exactly what open() acquired.
  private volatile boolean sessionSlotAcquired = false;

  public SparkConnectInterpreter(Properties properties) {
    super(properties);
  }

  @Override
  public synchronized void open() throws InterpreterException {
    if (sparkSession != null) {
      LOGGER.warn("open() called but sparkSession is already active — skipping. "
          + "Call close() first to restart the interpreter.");
      return;
    }

    try {
      currentUser = getUserName();
      if (StringUtils.isBlank(currentUser)) {
        currentUser = "anonymous";
      }

      int maxSessions = Integer.parseInt(
          getProperty("zeppelin.spark.connect.maxSessionsPerUser",
              String.valueOf(DEFAULT_MAX_SESSIONS_PER_USER)));

      if (!acquireSessionSlot(currentUser, maxSessions)) {
        // Report the user's actual session count (previously maxSessions was
        // formatted into both placeholders).
        throw new InterpreterException(
            String.format("User '%s' already has %d active Spark Connect sessions "
                + "(max %d). Close an existing interpreter before opening a new one.",
                currentUser, getActiveSessionCount(currentUser), maxSessions));
      }
      sessionSlotAcquired = true;

      LOGGER.info("Opening SparkConnectInterpreter for user: {} (session slot {}/{})",
          currentUser, getActiveSessionCount(currentUser), maxSessions);

      String remoteUrl = SparkConnectUtils.buildConnectionString(getProperties(), currentUser);
      // Redact credentials before the connection string reaches the log.
      LOGGER.info("Connecting to Spark Connect server at: {}",
          remoteUrl.replaceAll("token=[^;]*", "token=[REDACTED]")
              .replaceAll("user_id=[^;]*", "user_id=[REDACTED]"));

      // Clear the thread-local active session on the Spark Connect client so that
      // getOrCreate() creates a fresh remote session rather than reusing the previous
      // closed one. We intentionally do NOT call clearDefaultSession() because that
      // is a JVM-global operation and would disrupt other interpreter instances that
      // are concurrently active in the same process.
      try {
        SparkSession.clearActiveSession();
        LOGGER.info("Cleared thread-local active Spark session (safe for multi-interpreter)");
      } catch (Exception e) {
        LOGGER.warn("Could not clear active Spark session (non-fatal): {}", e.getMessage());
      }

      SparkSession.Builder builder = SparkSession.builder().remote(remoteUrl);

      String appName = getProperty("spark.app.name", "Zeppelin Spark Connect");
      if (StringUtils.isNotBlank(appName)) {
        builder.appName(appName);
      }

      // Default gRPC message cap: 128 MiB — large result batches need headroom.
      String grpcMaxMsgSize = getProperty(
          "spark.connect.grpc.maxMessageSize", "134217728");
      builder.config("spark.connect.grpc.maxMessageSize", grpcMaxMsgSize);

      // Forward remaining spark.* properties, skipping connection-level keys that
      // are already encoded in the remote URL or set explicitly above.
      for (Object key : getProperties().keySet()) {
        String keyStr = key.toString();
        String value = getProperties().getProperty(keyStr);
        if (StringUtils.isNotBlank(value)
            && keyStr.startsWith("spark.")
            && !keyStr.equals("spark.remote")
            && !keyStr.equals("spark.connect.token")
            && !keyStr.equals("spark.connect.use_ssl")
            && !keyStr.equals("spark.app.name")
            && !keyStr.equals("spark.connect.grpc.maxMessageSize")) {
          builder.config(keyStr, value);
        }
      }

      sparkSession = builder.getOrCreate();
      LOGGER.info("Spark Connect session established for user: {}", currentUser);

      maxResult = Integer.parseInt(getProperty("zeppelin.spark.maxResult", "1000"));
      sqlSplitter = new SqlSplitter();
    } catch (InterpreterException ie) {
      throw ie;
    } catch (Exception e) {
      // Give the slot back on any failure so the user is not permanently charged
      // for a session that never came up.
      if (sessionSlotAcquired) {
        releaseSessionSlot(currentUser);
        sessionSlotAcquired = false;
      }
      LOGGER.error("Failed to connect to Spark Connect server", e);
      throw new InterpreterException("Failed to connect to Spark Connect server: "
          + e.getMessage(), e);
    }
  }

  @Override
  public void close() throws InterpreterException {
    LOGGER.info("Closing SparkConnectInterpreter for user: {} (sparkSession={})",
        currentUser, sparkSession != null ? "active" : "null");
    if (sparkSession != null) {
      try {
        sparkSession.close();
        LOGGER.info("Spark Connect session closed for user: {}", currentUser);
      } catch (Exception e) {
        LOGGER.warn("Error closing Spark Connect session", e);
      } finally {
        sparkSession = null;
      }
    } else {
      LOGGER.info("close() called but no active sparkSession — nothing to tear down");
    }
    if (sessionSlotAcquired) {
      releaseSessionSlot(currentUser);
      sessionSlotAcquired = false;
    }
  }

  @Override
  public ZeppelinContext getZeppelinContext() {
    return null;
  }

  @Override
  public InterpreterResult internalInterpret(String st, InterpreterContext context)
      throws InterpreterException {
    if (sparkSession == null) {
      return new InterpreterResult(Code.ERROR,
          "Spark Connect session is not initialized. Check connection settings.");
    }

    String noteId = context.getNoteId();
    if (StringUtils.isBlank(noteId)) {
      return new InterpreterResult(Code.ERROR,
          "Note ID is missing from interpreter context.");
    }

    // Per-notebook lock: only one query at a time inside a notebook.
    ReentrantLock notebookLock = NotebookLockManager.getNotebookLock(noteId);
    notebookLock.lock();
    try {
      List<String> sqls = sqlSplitter.splitSql(st);
      int limit = parseLimit(context);

      boolean useStreaming = Boolean.parseBoolean(
          getProperty("zeppelin.spark.connect.streamResults", "false"));

      String curSql = null;
      try {
        for (String sql : sqls) {
          curSql = sql;
          if (StringUtils.isBlank(sql)) {
            continue;
          }
          Dataset<Row> df = sparkSession.sql(sql);
          if (useStreaming) {
            SparkConnectUtils.streamDataFrame(df, limit, context.out);
          } else {
            String result = SparkConnectUtils.showDataFrame(df, limit);
            context.out.write(result);
          }
        }
        context.out.flush();
      } catch (Exception e) {
        return handleSqlException(e, curSql, context);
      }

      return new InterpreterResult(Code.SUCCESS);
    } finally {
      notebookLock.unlock();
    }
  }

  /**
   * Parse the per-paragraph "limit" local property, falling back to
   * zeppelin.spark.maxResult when absent or malformed. (Previously a malformed
   * value propagated a NumberFormatException and failed the paragraph outright.)
   */
  private int parseLimit(InterpreterContext context) {
    String raw = context.getLocalProperties()
        .getOrDefault("limit", String.valueOf(maxResult));
    try {
      return Integer.parseInt(raw.trim());
    } catch (NumberFormatException e) {
      LOGGER.warn("Invalid limit '{}', falling back to {}", raw, maxResult);
      return maxResult;
    }
  }

  // ---- session-slot helpers (static state shared across all instances) ----

  /**
   * Atomically claim one session slot for the user via ConcurrentHashMap.compute,
   * avoiding the class-wide lock a static synchronized method would impose.
   *
   * @return true if a slot was available and claimed
   */
  private static boolean acquireSessionSlot(String user, int maxSessions) {
    final boolean[] acquired = {false};
    userSessionCount.compute(user, (k, current) -> {
      int count = current == null ? 0 : current;
      if (count >= maxSessions) {
        return current; // leave count unchanged — at capacity
      }
      acquired[0] = true;
      return count + 1;
    });
    if (acquired[0]) {
      LOGGER.info("Acquired session slot for user {}. Active sessions: {}/{}",
          user, userSessionCount.getOrDefault(user, 0), maxSessions);
    } else {
      LOGGER.warn("User {} already has {} active Spark Connect sessions (max {})",
          user, userSessionCount.getOrDefault(user, 0), maxSessions);
    }
    return acquired[0];
  }

  /**
   * Atomically release one session slot for the user; the map entry is dropped
   * when the count reaches zero.
   */
  private static void releaseSessionSlot(String user) {
    if (user == null) {
      return;
    }
    Integer remaining = userSessionCount.computeIfPresent(user,
        (k, current) -> current <= 1 ? null : current - 1);
    LOGGER.info("Released session slot for user {}. Remaining sessions: {}",
        user, remaining == null ? 0 : remaining);
  }

  /** Visible for testing. */
  static int getActiveSessionCount(String user) {
    return userSessionCount.getOrDefault(user, 0);
  }

  /**
   * Write the failing SQL and its stacktrace (or a short message when
   * zeppelin.spark.sql.stacktrace=false) to the paragraph output.
   */
  private InterpreterResult handleSqlException(Exception e, String sql,
      InterpreterContext context) {
    try {
      LOGGER.error("Error executing SQL: {}", sql, e);
      context.out.write("\nError in SQL: " + sql + "\n");
      if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "true"))) {
        // Prefer the root cause — the client wrapper adds little information.
        if (e.getCause() != null) {
          context.out.write(ExceptionUtils.getStackTrace(e.getCause()));
        } else {
          context.out.write(ExceptionUtils.getStackTrace(e));
        }
      } else {
        String msg = e.getCause() != null ? e.getCause().getMessage() : e.getMessage();
        context.out.write(msg
            + "\nSet zeppelin.spark.sql.stacktrace = true to see full stacktrace");
      }
      context.out.flush();
    } catch (IOException ex) {
      LOGGER.error("Failed to write error output", ex);
    }
    return new InterpreterResult(Code.ERROR);
  }

  @Override
  public void cancel(InterpreterContext context) throws InterpreterException {
    if (sparkSession != null) {
      try {
        // interruptAll() aborts every running operation on this connect session.
        sparkSession.interruptAll();
      } catch (Exception e) {
        LOGGER.warn("Error interrupting Spark Connect session", e);
      }
    }
  }

  @Override
  public FormType getFormType() {
    return FormType.SIMPLE;
  }

  @Override
  public int getProgress(InterpreterContext context) throws InterpreterException {
    // Spark Connect exposes no job progress to the client; report none.
    return 0;
  }

  @Override
  public List<InterpreterCompletion> completion(String buf, int cursor,
      InterpreterContext interpreterContext) throws InterpreterException {
    return new ArrayList<>();
  }

  /** @return the active remote session, or null before open() / after close(). */
  public SparkSession getSparkSession() {
    return sparkSession;
  }

  public int getMaxResult() {
    return maxResult;
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.zeppelin.interpreter.AbstractInterpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.ZeppelinContext; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import org.apache.zeppelin.interpreter.util.SqlSplitter; +import org.apache.zeppelin.scheduler.Scheduler; +import org.apache.zeppelin.scheduler.SchedulerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Spark Connect SQL interpreter for Zeppelin. + * Delegates to SparkConnectInterpreter for the SparkSession, providing + * dedicated SQL execution with concurrent query support. 
+ */ +public class SparkConnectSqlInterpreter extends AbstractInterpreter { + + private static final Logger LOGGER = LoggerFactory.getLogger(SparkConnectSqlInterpreter.class); + + private SparkConnectInterpreter sparkConnectInterpreter; + private SqlSplitter sqlSplitter; + + public SparkConnectSqlInterpreter(Properties properties) { + super(properties); + } + + @Override + public void open() throws InterpreterException { + this.sparkConnectInterpreter = + getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class); + this.sqlSplitter = new SqlSplitter(); + } + + @Override + public void close() throws InterpreterException { + sparkConnectInterpreter = null; + } + + @Override + protected boolean isInterpolate() { + return Boolean.parseBoolean(getProperty("zeppelin.spark.sql.interpolation", "false")); + } + + @Override + public ZeppelinContext getZeppelinContext() { + return null; + } + + @Override + public InterpreterResult internalInterpret(String st, InterpreterContext context) + throws InterpreterException { + SparkSession sparkSession = sparkConnectInterpreter.getSparkSession(); + if (sparkSession == null) { + return new InterpreterResult(Code.ERROR, + "Spark Connect session is not initialized. 
Check connection settings."); + } + + // Get noteId from context for notebook-level synchronization + String noteId = context.getNoteId(); + if (StringUtils.isBlank(noteId)) { + return new InterpreterResult(Code.ERROR, + "Note ID is missing from interpreter context."); + } + + // Get or create lock for this notebook to ensure sequential execution + // This ensures one query at a time per notebook, even with concurrentSQL enabled + ReentrantLock notebookLock = NotebookLockManager.getNotebookLock(noteId); + + // Acquire lock to ensure only one query executes at a time for this notebook + notebookLock.lock(); + try { + List sqls = sqlSplitter.splitSql(st); + int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit", + String.valueOf(sparkConnectInterpreter.getMaxResult()))); + + boolean useStreaming = Boolean.parseBoolean( + getProperty("zeppelin.spark.connect.streamResults", "false")); + + String curSql = null; + try { + for (String sql : sqls) { + curSql = sql; + if (StringUtils.isBlank(sql)) { + continue; + } + Dataset df = sparkSession.sql(sql); + if (useStreaming) { + SparkConnectUtils.streamDataFrame(df, maxResult, context.out); + } else { + String result = SparkConnectUtils.showDataFrame(df, maxResult); + context.out.write(result); + } + } + context.out.flush(); + } catch (Exception e) { + try { + LOGGER.error("Error executing SQL: {}", curSql, e); + context.out.write("\nError in SQL: " + curSql + "\n"); + if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "true"))) { + if (e.getCause() != null) { + context.out.write(ExceptionUtils.getStackTrace(e.getCause())); + } else { + context.out.write(ExceptionUtils.getStackTrace(e)); + } + } else { + String msg = e.getCause() != null ? 
e.getCause().getMessage() : e.getMessage(); + context.out.write(msg + + "\nSet zeppelin.spark.sql.stacktrace = true to see full stacktrace"); + } + context.out.flush(); + } catch (IOException ex) { + LOGGER.error("Failed to write error output", ex); + } + return new InterpreterResult(Code.ERROR); + } + + return new InterpreterResult(Code.SUCCESS); + } finally { + notebookLock.unlock(); + } + } + + @Override + public void cancel(InterpreterContext context) throws InterpreterException { + SparkSession sparkSession = sparkConnectInterpreter.getSparkSession(); + if (sparkSession != null) { + try { + sparkSession.interruptAll(); + } catch (Exception e) { + LOGGER.warn("Error interrupting Spark Connect session", e); + } + } + } + + @Override + public FormType getFormType() { + return FormType.SIMPLE; + } + + @Override + public int getProgress(InterpreterContext context) throws InterpreterException { + return 0; + } + + @Override + public List completion(String buf, int cursor, + InterpreterContext interpreterContext) throws InterpreterException { + return new ArrayList<>(); + } + + @Override + public Scheduler getScheduler() { + if (concurrentSQL()) { + int maxConcurrency = Integer.parseInt( + getProperty("zeppelin.spark.concurrentSQL.max", "10")); + return SchedulerFactory.singleton().createOrGetParallelScheduler( + SparkConnectSqlInterpreter.class.getName() + this.hashCode(), maxConcurrency); + } else { + try { + return getInterpreterInTheSameSessionByClassName( + SparkConnectInterpreter.class, false).getScheduler(); + } catch (InterpreterException e) { + throw new RuntimeException("Failed to get scheduler", e); + } + } + } + + private boolean concurrentSQL() { + return Boolean.parseBoolean(getProperty("zeppelin.spark.concurrentSQL")); + } +} diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java new file mode 100644 index 00000000000..ffad1fbb789 --- 
/dev/null +++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.commons.lang3.StringUtils; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.List; +import java.util.Properties; + +public class SparkConnectUtils { + + static final int DEFAULT_TRUNCATE_LENGTH = 256; + + private SparkConnectUtils() { + } + + /** + * Build the Spark Connect connection string from interpreter properties. + * Format: sc://hostname:port[/;param1=val1;param2=val2] + * + * Supports: token, use_ssl, user_id, and any extra params already in the URI. 
+ * Examples: + * sc://localhost:15002 + * sc://localhost:15002/;use_ssl=true;token=abc123;user_id=alice + * sc://ranking-cluster-m:8080 + */ + public static String buildConnectionString(Properties properties) { + return buildConnectionString(properties, null); + } + + /** + * Build the Spark Connect connection string, including user_id so the Spark Connect + * server can attribute the session to the correct user in its own UI. + * + * @param properties interpreter properties + * @param userName the authenticated Zeppelin username; ignored if blank + */ + public static String buildConnectionString(Properties properties, String userName) { + String remote = properties.getProperty("spark.remote", "sc://localhost:15002"); + StringBuilder params = new StringBuilder(); + + String token = properties.getProperty("spark.connect.token", ""); + if (StringUtils.isNotBlank(token)) { + params.append(";token=").append(token); + } + + boolean useSsl = Boolean.parseBoolean( + properties.getProperty("spark.connect.use_ssl", "false")); + if (useSsl) { + params.append(";use_ssl=true"); + } + + if (StringUtils.isNotBlank(userName) && !remote.contains("user_id=")) { + params.append(";user_id=").append(userName); + } + + if (params.length() > 0) { + if (remote.contains(";")) { + remote = remote + params; + } else { + remote = remote + "/" + params; + } + } + return remote; + } + + /** + * Convert a Dataset to Zeppelin's %table format string. + * Applies limit before collecting to prevent OOM on the driver. + * Truncates cell values to avoid excessively wide output. 
+ */ + public static String showDataFrame(Dataset df, int maxResult) { + return showDataFrame(df, maxResult, DEFAULT_TRUNCATE_LENGTH); + } + + public static String showDataFrame(Dataset df, int maxResult, int truncateLength) { + StructType schema = df.schema(); + StructField[] fields = schema.fields(); + + int effectiveLimit = Math.max(1, Math.min(maxResult, 100_000)); + + int estimatedRowSize = Math.max(fields.length * 20, 100); + int estimatedTotalBytes = estimatedRowSize * effectiveLimit; + StringBuilder sb = new StringBuilder(Math.min(estimatedTotalBytes, 10 * 1024 * 1024)); + sb.append("%table "); + + for (int i = 0; i < fields.length; i++) { + if (i > 0) { + sb.append('\t'); + } + sb.append(replaceReservedChars(fields[i].name())); + } + sb.append('\n'); + + List rows = df.limit(effectiveLimit).collectAsList(); + for (Row row : rows) { + for (int i = 0; i < row.length(); i++) { + if (i > 0) { + sb.append('\t'); + } + Object value = row.get(i); + String cellStr = value == null ? "null" : value.toString(); + if (truncateLength > 0 && cellStr.length() > truncateLength) { + cellStr = cellStr.substring(0, truncateLength) + "..."; + } + sb.append(replaceReservedChars(cellStr)); + } + sb.append('\n'); + } + + return sb.toString(); + } + + /** + * Stream a Dataset as Zeppelin %table format directly to an OutputStream, + * avoiding building the entire result in memory. + * Preferred for large result sets. 
+ */ + public static void streamDataFrame(Dataset df, int maxResult, OutputStream out) + throws IOException { + streamDataFrame(df, maxResult, DEFAULT_TRUNCATE_LENGTH, out); + } + + public static void streamDataFrame(Dataset df, int maxResult, + int truncateLength, OutputStream out) throws IOException { + StructType schema = df.schema(); + StructField[] fields = schema.fields(); + + int effectiveLimit = Math.max(1, Math.min(maxResult, 100_000)); + + StringBuilder header = new StringBuilder("%table "); + for (int i = 0; i < fields.length; i++) { + if (i > 0) { + header.append('\t'); + } + header.append(replaceReservedChars(fields[i].name())); + } + header.append('\n'); + out.write(header.toString().getBytes(StandardCharsets.UTF_8)); + + Iterator it = df.limit(effectiveLimit).toLocalIterator(); + StringBuilder rowBuf = new StringBuilder(256); + while (it.hasNext()) { + rowBuf.setLength(0); + Row row = it.next(); + for (int i = 0; i < row.length(); i++) { + if (i > 0) { + rowBuf.append('\t'); + } + Object value = row.get(i); + String cellStr = value == null ? 
"null" : value.toString(); + if (truncateLength > 0 && cellStr.length() > truncateLength) { + cellStr = cellStr.substring(0, truncateLength) + "..."; + } + rowBuf.append(replaceReservedChars(cellStr)); + } + rowBuf.append('\n'); + out.write(rowBuf.toString().getBytes(StandardCharsets.UTF_8)); + } + out.flush(); + } + + static String replaceReservedChars(String str) { + if (str == null) { + return "null"; + } + return str.replace('\t', ' ').replace('\n', ' '); + } +} diff --git a/spark-connect/src/main/resources/interpreter-setting.json b/spark-connect/src/main/resources/interpreter-setting.json new file mode 100644 index 00000000000..78d68fda936 --- /dev/null +++ b/spark-connect/src/main/resources/interpreter-setting.json @@ -0,0 +1,187 @@ +[ + { + "group": "spark-connect", + "name": "spark-connect", + "className": "org.apache.zeppelin.spark.SparkConnectInterpreter", + "defaultInterpreter": true, + "properties": { + "spark.remote": { + "envName": "SPARK_REMOTE", + "propertyName": "spark.remote", + "defaultValue": "sc://localhost:15002", + "description": "Spark Connect server URI (e.g. sc://localhost:15002 or sc://dataproc-master:8080)", + "type": "string" + }, + "spark.app.name": { + "envName": null, + "propertyName": "spark.app.name", + "defaultValue": "Zeppelin Spark Connect", + "description": "Spark application name", + "type": "string" + }, + "zeppelin.spark.maxResult": { + "envName": null, + "propertyName": "zeppelin.spark.maxResult", + "defaultValue": "1000", + "description": "Max number of rows to display", + "type": "number" + }, + "zeppelin.spark.sql.stacktrace": { + "envName": null, + "propertyName": "zeppelin.spark.sql.stacktrace", + "defaultValue": true, + "description": "Show full exception stacktrace for SQL errors", + "type": "checkbox" + }, + "spark.connect.grpc.maxMessageSize": { + "envName": null, + "propertyName": "spark.connect.grpc.maxMessageSize", + "defaultValue": "134217728", + "description": "Max gRPC message size in bytes (default 128MB). 
Increase for large result sets.", + "type": "number" + }, + "zeppelin.spark.connect.streamResults": { + "envName": null, + "propertyName": "zeppelin.spark.connect.streamResults", + "defaultValue": false, + "description": "Stream query results row-by-row instead of building full result in memory. Recommended for large result sets.", + "type": "checkbox" + }, + "zeppelin.spark.connect.maxSessionsPerUser": { + "envName": null, + "propertyName": "zeppelin.spark.connect.maxSessionsPerUser", + "defaultValue": "5", + "description": "Maximum number of Spark Connect sessions (SparkSession instances) per user. Each interpreter instance creates one session.", + "type": "number" + } + }, + "editor": { + "language": "sql", + "editOnDblClick": false, + "completionKey": "TAB", + "completionSupport": false + } + }, + { + "group": "spark-connect", + "name": "sql", + "className": "org.apache.zeppelin.spark.SparkConnectSqlInterpreter", + "properties": { + "zeppelin.spark.concurrentSQL": { + "envName": null, + "propertyName": "zeppelin.spark.concurrentSQL", + "defaultValue": true, + "description": "Execute multiple SQL concurrently", + "type": "checkbox" + }, + "zeppelin.spark.concurrentSQL.max": { + "envName": null, + "propertyName": "zeppelin.spark.concurrentSQL.max", + "defaultValue": "10", + "description": "Max concurrent SQL executions", + "type": "number" + }, + "zeppelin.spark.sql.stacktrace": { + "envName": null, + "propertyName": "zeppelin.spark.sql.stacktrace", + "defaultValue": true, + "description": "Show full exception stacktrace for SQL errors", + "type": "checkbox" + } + }, + "editor": { + "language": "sql", + "editOnDblClick": false, + "completionKey": "TAB", + "completionSupport": false + } + }, + { + "group": "spark-connect", + "name": "pyspark", + "className": "org.apache.zeppelin.spark.PySparkConnectInterpreter", + "properties": { + "spark.remote": { + "envName": "SPARK_REMOTE", + "propertyName": "spark.remote", + "defaultValue": "sc://localhost:15002", + 
"description": "Spark Connect server URI (e.g. sc://localhost:15002 or sc://dataproc-master:8080)", + "type": "string" + }, + "spark.connect.token": { + "envName": "SPARK_CONNECT_TOKEN", + "propertyName": "spark.connect.token", + "defaultValue": "", + "description": "Authentication token for Spark Connect (optional)", + "type": "string" + }, + "spark.connect.use_ssl": { + "envName": "SPARK_CONNECT_USE_SSL", + "propertyName": "spark.connect.use_ssl", + "defaultValue": false, + "description": "Use SSL for Spark Connect connection", + "type": "checkbox" + }, + "spark.app.name": { + "envName": null, + "propertyName": "spark.app.name", + "defaultValue": "Zeppelin Spark Connect", + "description": "Spark application name", + "type": "string" + }, + "spark.pyspark.driver.python": { + "envName": "PYSPARK_DRIVER_PYTHON", + "propertyName": "spark.pyspark.driver.python", + "defaultValue": "", + "description": "Python executable to use for PySpark driver (highest priority). Follows Spark's PySpark detection pattern.", + "type": "string" + }, + "spark.pyspark.python": { + "envName": "PYSPARK_PYTHON", + "propertyName": "spark.pyspark.python", + "defaultValue": "", + "description": "Python executable to use for PySpark (second priority). Can also be set via PYSPARK_PYTHON environment variable.", + "type": "string" + }, + "zeppelin.python": { + "envName": "ZEPPELIN_PYTHON", + "propertyName": "zeppelin.python", + "defaultValue": "", + "description": "Python executable command (optional fallback). 
If not set, defaults to 'python' and uses system PATH (matching Spark's PySpark behavior).", + "type": "string" + }, + "zeppelin.pyspark.connect.useIPython": { + "envName": null, + "propertyName": "zeppelin.pyspark.connect.useIPython", + "defaultValue": true, + "description": "Use IPython if available", + "type": "checkbox" + }, + "zeppelin.spark.maxResult": { + "envName": null, + "propertyName": "zeppelin.spark.maxResult", + "defaultValue": "1000", + "description": "Max number of rows to display", + "type": "number" + } + }, + "editor": { + "language": "python", + "editOnDblClick": false, + "completionKey": "TAB", + "completionSupport": true + } + }, + { + "group": "spark-connect", + "name": "ipyspark", + "className": "org.apache.zeppelin.spark.IPySparkConnectInterpreter", + "properties": {}, + "editor": { + "language": "python", + "editOnDblClick": false, + "completionKey": "TAB", + "completionSupport": true + } + } +] diff --git a/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py b/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py new file mode 100644 index 00000000000..17a989ddcba --- /dev/null +++ b/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py @@ -0,0 +1,632 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# IPython variant: same wrappers as zeppelin_sparkconnect.py. +# +# Guards against OOM on large collect/toPandas by enforcing configurable +# row limits and providing safe iteration helpers. + +import sys +import warnings + +intp = gateway.entry_point +_jspark = intp.getSparkSession() +_max_result = intp.getMaxResult() + +_COLLECT_LIMIT_DEFAULT = _max_result +_COLLECT_WARN_THRESHOLD = 100000 + + +class Row(tuple): + """Lightweight PySpark-compatible Row that wraps values extracted from Java Row objects.""" + + def __new__(cls, **kwargs): + row = tuple.__new__(cls, kwargs.values()) + row.__dict__['_fields'] = tuple(kwargs.keys()) + return row + + def __repr__(self): + pairs = ", ".join("%s=%r" % (k, v) for k, v in zip(self._fields, self)) + return "Row(%s)" % pairs + + def __getattr__(self, name): + try: + idx = self._fields.index(name) + return self[idx] + except ValueError: + raise AttributeError("Row has no field '%s'" % name) + + def asDict(self): + return dict(zip(self._fields, self)) + + +def _convert_java_row(jrow, col_names): + """Convert a single Java Row to a Python Row.""" + values = {} + for i, col in enumerate(col_names): + val = jrow.get(i) + if hasattr(val, 'getClass'): + val = str(val) + values[col] = val + return Row(**values) + + +def _convert_java_rows(jdf): + """Convert a Java Dataset's collected rows to Python Row objects.""" + fields = jdf.schema().fields() + col_names = [f.name() for f in fields] + jrows = jdf.collectAsList() + return [_convert_java_row(r, col_names) for r in jrows] + + +# --------------------------------------------------------------------------- +# Py4j / type-conversion helpers for createDataFrame and __getattr__ +# --------------------------------------------------------------------------- + +def _is_java_object(obj): + """Check if obj is a Py4j proxy.""" + return hasattr(obj, '_get_object_id') + + +def 
_is_java_dataset(obj): + """Check if a Py4j proxy represents a Spark Dataset.""" + if not _is_java_object(obj): + return False + try: + return 'Dataset' in obj.getClass().getName() + except Exception: + return False + + +_PYSPARK_TO_JAVA_TYPES = { + 'StringType': 'StringType', + 'IntegerType': 'IntegerType', + 'LongType': 'LongType', + 'DoubleType': 'DoubleType', + 'FloatType': 'FloatType', + 'BooleanType': 'BooleanType', + 'ShortType': 'ShortType', + 'ByteType': 'ByteType', + 'DateType': 'DateType', + 'TimestampType': 'TimestampType', + 'BinaryType': 'BinaryType', + 'NullType': 'NullType', +} + + +def _pyspark_type_to_java(dt): + """Convert a PySpark DataType instance to a Java DataType via Py4j gateway.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + type_name = type(dt).__name__ + if type_name in _PYSPARK_TO_JAVA_TYPES: + return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name]) + if type_name == 'DecimalType': + return DataTypes.createDecimalType(dt.precision, dt.scale) + if type_name == 'ArrayType': + return DataTypes.createArrayType( + _pyspark_type_to_java(dt.elementType), dt.containsNull) + if type_name == 'MapType': + return DataTypes.createMapType( + _pyspark_type_to_java(dt.keyType), + _pyspark_type_to_java(dt.valueType), dt.valueContainsNull) + if type_name == 'StructType': + return _pyspark_schema_to_java(dt) + return DataTypes.StringType + + +def _pyspark_schema_to_java(pyspark_schema): + """Convert a PySpark StructType to a Java StructType.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + java_fields = gateway.jvm.java.util.ArrayList() + for field in pyspark_schema.fields: + jtype = _pyspark_type_to_java(field.dataType) + java_fields.add(DataTypes.createStructField( + field.name, jtype, getattr(field, 'nullable', True))) + return DataTypes.createStructType(java_fields) + + +def _infer_java_type(value): + """Infer a Java DataType from a Python value.""" + DataTypes = 
gateway.jvm.org.apache.spark.sql.types.DataTypes + if value is None: + return DataTypes.StringType + if isinstance(value, bool): + return DataTypes.BooleanType + if isinstance(value, int): + return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType + if isinstance(value, float): + return DataTypes.DoubleType + return DataTypes.StringType + + +def _resolve_schema(schema, data): + """Resolve any schema representation to a Java StructType.""" + if schema is None: + return _infer_schema(data) + if _is_java_object(schema): + return schema + if hasattr(schema, 'fields') and not _is_java_object(schema): + return _pyspark_schema_to_java(schema) + if isinstance(schema, str): + try: + return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema) + except Exception: + raise ValueError("Cannot parse DDL schema: %s" % schema) + if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str): + return _schema_from_names(schema, data) + raise ValueError("Unsupported schema type: %s" % type(schema).__name__) + + +def _infer_schema(data): + """Infer a Java StructType from the first element of the data.""" + if not data: + raise ValueError("Cannot infer schema from empty data without a schema") + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + first = data[0] + if isinstance(first, Row): + names, values = list(first._fields), list(first) + elif isinstance(first, dict): + names, values = list(first.keys()), list(first.values()) + elif isinstance(first, (list, tuple)): + names = ["_%d" % (i + 1) for i in range(len(first))] + values = list(first) + else: + names, values = ["value"], [first] + java_fields = gateway.jvm.java.util.ArrayList() + for i, name in enumerate(names): + java_fields.add(DataTypes.createStructField( + name, _infer_java_type(values[i] if i < len(values) else None), True)) + return DataTypes.createStructType(java_fields) + + +def _schema_from_names(col_names, data): + """Create a Java StructType from 
column name list, inferring types from data.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + first = data[0] if data else None + java_fields = gateway.jvm.java.util.ArrayList() + for i, name in enumerate(col_names): + jtype = DataTypes.StringType + if first is not None: + val = None + if isinstance(first, (list, tuple)) and i < len(first): + val = first[i] + elif isinstance(first, dict): + val = first.get(name) + elif isinstance(first, Row) and i < len(first): + val = first[i] + if val is not None: + jtype = _infer_java_type(val) + java_fields.add(DataTypes.createStructField(name, jtype, True)) + return DataTypes.createStructType(java_fields) + + +def _to_java_rows(data, col_names): + """Convert Python data (list of Row/dict/tuple/list) to Java ArrayList.""" + RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory + java_rows = gateway.jvm.java.util.ArrayList() + for item in data: + if isinstance(item, Row): + java_rows.add(RowFactory.create(*list(item))) + elif isinstance(item, dict): + java_rows.add(RowFactory.create(*[item.get(c) for c in col_names])) + elif isinstance(item, (list, tuple)): + java_rows.add(RowFactory.create(*list(item))) + else: + java_rows.add(RowFactory.create(item)) + return java_rows + + +class SparkConnectDataFrame(object): + """Wrapper around a Java Dataset with production-safe data retrieval.""" + + def __init__(self, jdf): + self._jdf = jdf + + def show(self, n=20, truncate=True): + effective_n = min(n, _max_result) + print(intp.formatDataFrame(self._jdf, effective_n)) + + def collect(self, limit=None): + """Collect rows to the driver as Python Row objects. + + Args: + limit: Max rows to collect. Defaults to zeppelin.spark.maxResult. + Pass limit=-1 to collect ALL rows (use with caution). + """ + if limit is None: + limit = _COLLECT_LIMIT_DEFAULT + if limit == -1: + row_count = self._jdf.count() + if row_count > _COLLECT_WARN_THRESHOLD: + warnings.warn( + "Collecting %d rows to driver. This may cause OOM. 
" + "Consider using .limit() or .toPandas() with a smaller subset." + % row_count) + return _convert_java_rows(self._jdf) + return _convert_java_rows(self._jdf.limit(limit)) + + def take(self, n): + return _convert_java_rows(self._jdf.limit(n)) + + def head(self, n=1): + rows = self.take(n) + if n == 1: + return rows[0] if rows else None + return rows + + def first(self): + return self.head(1) + + def toPandas(self, limit=None): + """Convert to pandas DataFrame. Applies a safety limit. + + Args: + limit: Max rows. Defaults to zeppelin.spark.maxResult. + Pass limit=-1 for all rows (use with caution on large data). + """ + try: + import pandas as pd + except ImportError: + raise ImportError( + "pandas is required for toPandas(). " + "Install it with: pip install pandas") + + if limit is None: + limit = _COLLECT_LIMIT_DEFAULT + if limit == -1: + source_jdf = self._jdf + else: + source_jdf = self._jdf.limit(limit) + + fields = source_jdf.schema().fields() + col_names = [f.name() for f in fields] + jrows = source_jdf.collectAsList() + + if len(jrows) == 0: + return pd.DataFrame(columns=col_names) + + rows_data = [] + for row in jrows: + rows_data.append([row.get(i) for i in range(len(col_names))]) + + return pd.DataFrame(rows_data, columns=col_names) + + def count(self): + return self._jdf.count() + + def limit(self, n): + return SparkConnectDataFrame(self._jdf.limit(n)) + + def filter(self, condition): + return SparkConnectDataFrame(self._jdf.filter(condition)) + + def select(self, *cols): + return SparkConnectDataFrame(self._jdf.select(*cols)) + + def where(self, condition): + return self.filter(condition) + + def groupBy(self, *cols): + return self._jdf.groupBy(*cols) + + def orderBy(self, *cols): + return SparkConnectDataFrame(self._jdf.orderBy(*cols)) + + def sort(self, *cols): + return self.orderBy(*cols) + + def distinct(self): + return SparkConnectDataFrame(self._jdf.distinct()) + + def drop(self, *cols): + return SparkConnectDataFrame(self._jdf.drop(*cols)) + + 
def dropDuplicates(self, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.dropDuplicates(*cols)) + return SparkConnectDataFrame(self._jdf.dropDuplicates()) + + def join(self, other, on=None, how="inner"): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + if on is not None: + return SparkConnectDataFrame(self._jdf.join(other_jdf, on, how)) + return SparkConnectDataFrame(self._jdf.join(other_jdf)) + + def union(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.union(other_jdf)) + + def withColumn(self, colName, col): + return SparkConnectDataFrame(self._jdf.withColumn(colName, col)) + + def withColumnRenamed(self, existing, new): + return SparkConnectDataFrame(self._jdf.withColumnRenamed(existing, new)) + + def cache(self): + self._jdf.cache() + return self + + def persist(self, storageLevel=None): + if storageLevel: + self._jdf.persist(storageLevel) + else: + self._jdf.persist() + return self + + def unpersist(self, blocking=False): + self._jdf.unpersist(blocking) + return self + + def explain(self, extended=False): + if extended: + self._jdf.explain(True) + else: + self._jdf.explain() + + def createOrReplaceTempView(self, name): + self._jdf.createOrReplaceTempView(name) + + def createTempView(self, name): + self._jdf.createTempView(name) + + def schema(self): + return self._jdf.schema() + + def dtypes(self): + schema = self._jdf.schema() + return [(f.name(), str(f.dataType())) for f in schema.fields()] + + def columns(self): + schema = self._jdf.schema() + return [f.name() for f in schema.fields()] + + def printSchema(self): + print(self._jdf.schema().treeString()) + + def describe(self, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.describe(*cols)) + return SparkConnectDataFrame(self._jdf.describe()) + + def summary(self, *statistics): + if statistics: + return SparkConnectDataFrame(self._jdf.summary(*statistics)) + return 
SparkConnectDataFrame(self._jdf.summary()) + + def isEmpty(self): + return self._jdf.isEmpty() + + def repartition(self, numPartitions, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols)) + return SparkConnectDataFrame(self._jdf.repartition(numPartitions)) + + def coalesce(self, numPartitions): + return SparkConnectDataFrame(self._jdf.coalesce(numPartitions)) + + def toDF(self, *cols): + return SparkConnectDataFrame(self._jdf.toDF(*cols)) + + def unionByName(self, other, allowMissingColumns=False): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame( + self._jdf.unionByName(other_jdf, allowMissingColumns)) + + def crossJoin(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf)) + + def subtract(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.subtract(other_jdf)) + + def intersect(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.intersect(other_jdf)) + + def exceptAll(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf)) + + def sample(self, withReplacement=None, fraction=None, seed=None): + if withReplacement is None and fraction is None: + raise ValueError("fraction must be specified") + if isinstance(withReplacement, float) and fraction is None: + fraction = withReplacement + withReplacement = False + if withReplacement is None: + withReplacement = False + if seed is not None: + return SparkConnectDataFrame( + self._jdf.sample(withReplacement, fraction, seed)) + return SparkConnectDataFrame( + self._jdf.sample(withReplacement, fraction)) + + def dropna(self, how="any", thresh=None, 
subset=None): + na = self._jdf.na() + if thresh is not None: + if subset: + return SparkConnectDataFrame(na.drop(thresh, subset)) + return SparkConnectDataFrame(na.drop(thresh)) + if subset: + return SparkConnectDataFrame(na.drop(how, subset)) + return SparkConnectDataFrame(na.drop(how)) + + def fillna(self, value, subset=None): + na = self._jdf.na() + if subset: + return SparkConnectDataFrame(na.fill(value, subset)) + return SparkConnectDataFrame(na.fill(value)) + + @property + def write(self): + return self._jdf.write() + + def __repr__(self): + try: + return "SparkConnectDataFrame[%s]" % ", ".join( + f.name() for f in self._jdf.schema().fields()) + except Exception: + return "SparkConnectDataFrame[schema unavailable]" + + def __getattr__(self, name): + attr = getattr(self._jdf, name) + if not callable(attr): + return attr + def _method_wrapper(*args, **kwargs): + result = attr(*args, **kwargs) + if _is_java_dataset(result): + return SparkConnectDataFrame(result) + return result + return _method_wrapper + + def __iter__(self): + """Safe iteration with default limit to prevent OOM.""" + return iter(_convert_java_rows(self._jdf.limit(_COLLECT_LIMIT_DEFAULT))) + + def __len__(self): + return int(self._jdf.count()) + + +class SparkConnectSession(object): + """Wraps the Java SparkSession so that sql() returns a wrapped DataFrame.""" + + def __init__(self, jsession): + self._jsession = jsession + + def sql(self, query): + return SparkConnectDataFrame(self._jsession.sql(query)) + + def table(self, tableName): + return SparkConnectDataFrame(self._jsession.table(tableName)) + + def read(self): + return self._jsession.read() + + def createDataFrame(self, data, schema=None): + """Create a SparkConnectDataFrame from Python data. 
+ + Supports: + - data: list of Row, list of tuples, list of dicts, pandas DataFrame + - schema: PySpark StructType, list of column names, DDL string, + Java StructType (Py4j proxy), or None (infer from data) + """ + try: + import pandas as pd + if isinstance(data, pd.DataFrame): + if schema is None: + schema = list(data.columns) + data = data.values.tolist() + except ImportError: + pass + + if _is_java_object(data): + if schema is None: + return SparkConnectDataFrame(self._jsession.createDataFrame(data)) + java_schema = _resolve_schema(schema, None) + return SparkConnectDataFrame( + self._jsession.createDataFrame(data, java_schema)) + + java_schema = _resolve_schema(schema, data) + col_names = [f.name() for f in java_schema.fields()] + java_rows = _to_java_rows(data, col_names) + return SparkConnectDataFrame( + self._jsession.createDataFrame(java_rows, java_schema)) + + def range(self, start, end=None, step=1, numPartitions=None): + if end is None: + end = start + start = 0 + if numPartitions: + return SparkConnectDataFrame( + self._jsession.range(start, end, step, numPartitions)) + return SparkConnectDataFrame(self._jsession.range(start, end, step)) + + @property + def catalog(self): + return self._jsession.catalog() + + @property + def version(self): + return self._jsession.version() + + @property + def conf(self): + return self._jsession.conf() + + def stop(self): + pass + + def __repr__(self): + return "SparkConnectSession (via Py4j)" + + def __getattr__(self, name): + return getattr(self._jsession, name) + + +def pip_install(*packages): + """Install Python packages into the interpreter pod's environment. 
+ + Usage: + pip_install("requests") + pip_install("requests", "pandas", "numpy==1.24.0") + pip_install("requests>=2.28,<3.0") + """ + import subprocess + import importlib + import site + if not packages: + print("Usage: pip_install('package1', 'package2', ...)") + return + cmd = [sys.executable, "-m", "pip", "install", "--quiet"] + list(packages) + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode == 0: + installed = ", ".join(packages) + print("Successfully installed: %s" % installed) + if result.stdout.strip(): + print(result.stdout.strip()) + importlib.invalidate_caches() + for new_path in site.getsitepackages() + [site.getusersitepackages()]: + if new_path not in sys.path: + sys.path.insert(0, new_path) + else: + print("pip install failed (exit code %d):" % result.returncode) + if result.stderr.strip(): + print(result.stderr.strip()) + if result.stdout.strip(): + print(result.stdout.strip()) + except subprocess.TimeoutExpired: + print("pip install timed out after 300 seconds") + except Exception as e: + print("pip install error: %s" % str(e)) + + +def display(obj, n=20): + """Databricks-compatible display function. + + For SparkConnectDataFrame, renders as Zeppelin %table format. + For other objects, falls back to print(). + """ + if isinstance(obj, SparkConnectDataFrame): + obj.show(n) + else: + print(obj) + + +spark = SparkConnectSession(_jspark) +sqlContext = sqlc = spark diff --git a/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py b/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py new file mode 100644 index 00000000000..46250f44674 --- /dev/null +++ b/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py @@ -0,0 +1,637 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Reuse the Java SparkSession from SparkConnectInterpreter via Py4j. +# Wraps the Java DataFrame so that show()/collect() output goes to +# Python stdout (which Zeppelin captures), matching the SQL interpreter behavior. +# +# Guards against OOM on large collect/toPandas by enforcing configurable +# row limits and providing safe iteration helpers. + +import sys +import warnings + +intp = gateway.entry_point +_jspark = intp.getSparkSession() +_max_result = intp.getMaxResult() + +_COLLECT_LIMIT_DEFAULT = _max_result +_COLLECT_WARN_THRESHOLD = 100000 + + +class Row(tuple): + """Lightweight PySpark-compatible Row that wraps values extracted from Java Row objects.""" + + def __new__(cls, **kwargs): + row = tuple.__new__(cls, kwargs.values()) + row.__dict__['_fields'] = tuple(kwargs.keys()) + return row + + def __repr__(self): + pairs = ", ".join("%s=%r" % (k, v) for k, v in zip(self._fields, self)) + return "Row(%s)" % pairs + + def __getattr__(self, name): + try: + idx = self._fields.index(name) + return self[idx] + except ValueError: + raise AttributeError("Row has no field '%s'" % name) + + def asDict(self): + return dict(zip(self._fields, self)) + + +def _convert_java_row(jrow, col_names): + """Convert a single Java Row to a Python Row.""" + values = {} + for i, col in enumerate(col_names): + val = jrow.get(i) + if hasattr(val, 'getClass'): + val = str(val) + values[col] = val + return 
Row(**values) + + +def _convert_java_rows(jdf): + """Convert a Java Dataset's collected rows to Python Row objects.""" + fields = jdf.schema().fields() + col_names = [f.name() for f in fields] + jrows = jdf.collectAsList() + return [_convert_java_row(r, col_names) for r in jrows] + + +# --------------------------------------------------------------------------- +# Py4j / type-conversion helpers for createDataFrame and __getattr__ +# --------------------------------------------------------------------------- + +def _is_java_object(obj): + """Check if obj is a Py4j proxy.""" + return hasattr(obj, '_get_object_id') + + +def _is_java_dataset(obj): + """Check if a Py4j proxy represents a Spark Dataset.""" + if not _is_java_object(obj): + return False + try: + return 'Dataset' in obj.getClass().getName() + except Exception: + return False + + +_PYSPARK_TO_JAVA_TYPES = { + 'StringType': 'StringType', + 'IntegerType': 'IntegerType', + 'LongType': 'LongType', + 'DoubleType': 'DoubleType', + 'FloatType': 'FloatType', + 'BooleanType': 'BooleanType', + 'ShortType': 'ShortType', + 'ByteType': 'ByteType', + 'DateType': 'DateType', + 'TimestampType': 'TimestampType', + 'BinaryType': 'BinaryType', + 'NullType': 'NullType', +} + + +def _pyspark_type_to_java(dt): + """Convert a PySpark DataType instance to a Java DataType via Py4j gateway.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + type_name = type(dt).__name__ + if type_name in _PYSPARK_TO_JAVA_TYPES: + return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name]) + if type_name == 'DecimalType': + return DataTypes.createDecimalType(dt.precision, dt.scale) + if type_name == 'ArrayType': + return DataTypes.createArrayType( + _pyspark_type_to_java(dt.elementType), dt.containsNull) + if type_name == 'MapType': + return DataTypes.createMapType( + _pyspark_type_to_java(dt.keyType), + _pyspark_type_to_java(dt.valueType), dt.valueContainsNull) + if type_name == 'StructType': + return _pyspark_schema_to_java(dt) + 
return DataTypes.StringType + + +def _pyspark_schema_to_java(pyspark_schema): + """Convert a PySpark StructType to a Java StructType.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + java_fields = gateway.jvm.java.util.ArrayList() + for field in pyspark_schema.fields: + jtype = _pyspark_type_to_java(field.dataType) + java_fields.add(DataTypes.createStructField( + field.name, jtype, getattr(field, 'nullable', True))) + return DataTypes.createStructType(java_fields) + + +def _infer_java_type(value): + """Infer a Java DataType from a Python value.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + if value is None: + return DataTypes.StringType + if isinstance(value, bool): + return DataTypes.BooleanType + if isinstance(value, int): + return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType + if isinstance(value, float): + return DataTypes.DoubleType + return DataTypes.StringType + + +def _resolve_schema(schema, data): + """Resolve any schema representation to a Java StructType.""" + if schema is None: + return _infer_schema(data) + if _is_java_object(schema): + return schema + if hasattr(schema, 'fields') and not _is_java_object(schema): + return _pyspark_schema_to_java(schema) + if isinstance(schema, str): + try: + return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema) + except Exception: + raise ValueError("Cannot parse DDL schema: %s" % schema) + if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str): + return _schema_from_names(schema, data) + raise ValueError("Unsupported schema type: %s" % type(schema).__name__) + + +def _infer_schema(data): + """Infer a Java StructType from the first element of the data.""" + if not data: + raise ValueError("Cannot infer schema from empty data without a schema") + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + first = data[0] + if isinstance(first, Row): + names, values = list(first._fields), list(first) + elif 
isinstance(first, dict): + names, values = list(first.keys()), list(first.values()) + elif isinstance(first, (list, tuple)): + names = ["_%d" % (i + 1) for i in range(len(first))] + values = list(first) + else: + names, values = ["value"], [first] + java_fields = gateway.jvm.java.util.ArrayList() + for i, name in enumerate(names): + java_fields.add(DataTypes.createStructField( + name, _infer_java_type(values[i] if i < len(values) else None), True)) + return DataTypes.createStructType(java_fields) + + +def _schema_from_names(col_names, data): + """Create a Java StructType from column name list, inferring types from data.""" + DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes + first = data[0] if data else None + java_fields = gateway.jvm.java.util.ArrayList() + for i, name in enumerate(col_names): + jtype = DataTypes.StringType + if first is not None: + val = None + if isinstance(first, (list, tuple)) and i < len(first): + val = first[i] + elif isinstance(first, dict): + val = first.get(name) + elif isinstance(first, Row) and i < len(first): + val = first[i] + if val is not None: + jtype = _infer_java_type(val) + java_fields.add(DataTypes.createStructField(name, jtype, True)) + return DataTypes.createStructType(java_fields) + + +def _to_java_rows(data, col_names): + """Convert Python data (list of Row/dict/tuple/list) to Java ArrayList.""" + RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory + java_rows = gateway.jvm.java.util.ArrayList() + for item in data: + if isinstance(item, Row): + java_rows.add(RowFactory.create(*list(item))) + elif isinstance(item, dict): + java_rows.add(RowFactory.create(*[item.get(c) for c in col_names])) + elif isinstance(item, (list, tuple)): + java_rows.add(RowFactory.create(*list(item))) + else: + java_rows.add(RowFactory.create(item)) + return java_rows + + +class SparkConnectDataFrame(object): + """Wrapper around a Java Dataset with production-safe data retrieval.""" + + def __init__(self, jdf): + self._jdf = jdf + + 
def show(self, n=20, truncate=True): + effective_n = min(n, _max_result) + print(intp.formatDataFrame(self._jdf, effective_n)) + + def collect(self, limit=None): + """Collect rows to the driver as Python Row objects. + + Args: + limit: Max rows to collect. Defaults to zeppelin.spark.maxResult. + Pass limit=-1 to collect ALL rows (use with caution). + """ + if limit is None: + limit = _COLLECT_LIMIT_DEFAULT + if limit == -1: + row_count = self._jdf.count() + if row_count > _COLLECT_WARN_THRESHOLD: + warnings.warn( + "Collecting %d rows to driver. This may cause OOM. " + "Consider using .limit() or .toPandas() with a smaller subset." + % row_count) + return _convert_java_rows(self._jdf) + return _convert_java_rows(self._jdf.limit(limit)) + + def take(self, n): + return _convert_java_rows(self._jdf.limit(n)) + + def head(self, n=1): + rows = self.take(n) + if n == 1: + return rows[0] if rows else None + return rows + + def first(self): + return self.head(1) + + def toPandas(self, limit=None): + """Convert to pandas DataFrame. Applies a safety limit. + + Tries to use pyarrow for efficient serialization if available, + otherwise falls back to row-by-row conversion through Py4j. + + Args: + limit: Max rows. Defaults to zeppelin.spark.maxResult. + Pass limit=-1 for all rows (use with caution on large data). + """ + try: + import pandas as pd + except ImportError: + raise ImportError( + "pandas is required for toPandas(). 
" + "Install it with: pip install pandas") + + if limit is None: + limit = _COLLECT_LIMIT_DEFAULT + if limit == -1: + source_jdf = self._jdf + else: + source_jdf = self._jdf.limit(limit) + + fields = source_jdf.schema().fields() + col_names = [f.name() for f in fields] + jrows = source_jdf.collectAsList() + + if len(jrows) == 0: + return pd.DataFrame(columns=col_names) + + rows_data = [] + for row in jrows: + rows_data.append([row.get(i) for i in range(len(col_names))]) + + return pd.DataFrame(rows_data, columns=col_names) + + def count(self): + return self._jdf.count() + + def limit(self, n): + return SparkConnectDataFrame(self._jdf.limit(n)) + + def filter(self, condition): + return SparkConnectDataFrame(self._jdf.filter(condition)) + + def select(self, *cols): + return SparkConnectDataFrame(self._jdf.select(*cols)) + + def where(self, condition): + return self.filter(condition) + + def groupBy(self, *cols): + return self._jdf.groupBy(*cols) + + def orderBy(self, *cols): + return SparkConnectDataFrame(self._jdf.orderBy(*cols)) + + def sort(self, *cols): + return self.orderBy(*cols) + + def distinct(self): + return SparkConnectDataFrame(self._jdf.distinct()) + + def drop(self, *cols): + return SparkConnectDataFrame(self._jdf.drop(*cols)) + + def dropDuplicates(self, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.dropDuplicates(*cols)) + return SparkConnectDataFrame(self._jdf.dropDuplicates()) + + def join(self, other, on=None, how="inner"): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + if on is not None: + return SparkConnectDataFrame(self._jdf.join(other_jdf, on, how)) + return SparkConnectDataFrame(self._jdf.join(other_jdf)) + + def union(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.union(other_jdf)) + + def withColumn(self, colName, col): + return SparkConnectDataFrame(self._jdf.withColumn(colName, col)) + + def 
withColumnRenamed(self, existing, new): + return SparkConnectDataFrame(self._jdf.withColumnRenamed(existing, new)) + + def cache(self): + self._jdf.cache() + return self + + def persist(self, storageLevel=None): + if storageLevel: + self._jdf.persist(storageLevel) + else: + self._jdf.persist() + return self + + def unpersist(self, blocking=False): + self._jdf.unpersist(blocking) + return self + + def explain(self, extended=False): + if extended: + self._jdf.explain(True) + else: + self._jdf.explain() + + def createOrReplaceTempView(self, name): + self._jdf.createOrReplaceTempView(name) + + def createTempView(self, name): + self._jdf.createTempView(name) + + def schema(self): + return self._jdf.schema() + + def dtypes(self): + schema = self._jdf.schema() + return [(f.name(), str(f.dataType())) for f in schema.fields()] + + def columns(self): + schema = self._jdf.schema() + return [f.name() for f in schema.fields()] + + def printSchema(self): + print(self._jdf.schema().treeString()) + + def describe(self, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.describe(*cols)) + return SparkConnectDataFrame(self._jdf.describe()) + + def summary(self, *statistics): + if statistics: + return SparkConnectDataFrame(self._jdf.summary(*statistics)) + return SparkConnectDataFrame(self._jdf.summary()) + + def isEmpty(self): + return self._jdf.isEmpty() + + def __repr__(self): + try: + return "SparkConnectDataFrame[%s]" % ", ".join( + f.name() for f in self._jdf.schema().fields()) + except Exception: + return "SparkConnectDataFrame[schema unavailable]" + + def repartition(self, numPartitions, *cols): + if cols: + return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols)) + return SparkConnectDataFrame(self._jdf.repartition(numPartitions)) + + def coalesce(self, numPartitions): + return SparkConnectDataFrame(self._jdf.coalesce(numPartitions)) + + def toDF(self, *cols): + return SparkConnectDataFrame(self._jdf.toDF(*cols)) + + def unionByName(self, other, 
allowMissingColumns=False): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame( + self._jdf.unionByName(other_jdf, allowMissingColumns)) + + def crossJoin(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf)) + + def subtract(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.subtract(other_jdf)) + + def intersect(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.intersect(other_jdf)) + + def exceptAll(self, other): + other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other + return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf)) + + def sample(self, withReplacement=None, fraction=None, seed=None): + if withReplacement is None and fraction is None: + raise ValueError("fraction must be specified") + if isinstance(withReplacement, float) and fraction is None: + fraction = withReplacement + withReplacement = False + if withReplacement is None: + withReplacement = False + if seed is not None: + return SparkConnectDataFrame( + self._jdf.sample(withReplacement, fraction, seed)) + return SparkConnectDataFrame( + self._jdf.sample(withReplacement, fraction)) + + def dropna(self, how="any", thresh=None, subset=None): + na = self._jdf.na() + if thresh is not None: + if subset: + return SparkConnectDataFrame(na.drop(thresh, subset)) + return SparkConnectDataFrame(na.drop(thresh)) + if subset: + return SparkConnectDataFrame(na.drop(how, subset)) + return SparkConnectDataFrame(na.drop(how)) + + def fillna(self, value, subset=None): + na = self._jdf.na() + if subset: + return SparkConnectDataFrame(na.fill(value, subset)) + return SparkConnectDataFrame(na.fill(value)) + + @property + def write(self): + return self._jdf.write() + 
+ def __getattr__(self, name): + attr = getattr(self._jdf, name) + if not callable(attr): + return attr + def _method_wrapper(*args, **kwargs): + result = attr(*args, **kwargs) + if _is_java_dataset(result): + return SparkConnectDataFrame(result) + return result + return _method_wrapper + + def __iter__(self): + """Safe iteration with default limit to prevent OOM.""" + return iter(_convert_java_rows(self._jdf.limit(_COLLECT_LIMIT_DEFAULT))) + + def __len__(self): + return int(self._jdf.count()) + + +class SparkConnectSession(object): + """Wraps the Java SparkSession so that sql() returns a wrapped DataFrame.""" + + def __init__(self, jsession): + self._jsession = jsession + + def sql(self, query): + return SparkConnectDataFrame(self._jsession.sql(query)) + + def table(self, tableName): + return SparkConnectDataFrame(self._jsession.table(tableName)) + + def read(self): + return self._jsession.read() + + def createDataFrame(self, data, schema=None): + """Create a SparkConnectDataFrame from Python data. 
+ + Supports: + - data: list of Row, list of tuples, list of dicts, pandas DataFrame + - schema: PySpark StructType, list of column names, DDL string, + Java StructType (Py4j proxy), or None (infer from data) + """ + try: + import pandas as pd + if isinstance(data, pd.DataFrame): + if schema is None: + schema = list(data.columns) + data = data.values.tolist() + except ImportError: + pass + + if _is_java_object(data): + if schema is None: + return SparkConnectDataFrame(self._jsession.createDataFrame(data)) + java_schema = _resolve_schema(schema, None) + return SparkConnectDataFrame( + self._jsession.createDataFrame(data, java_schema)) + + java_schema = _resolve_schema(schema, data) + col_names = [f.name() for f in java_schema.fields()] + java_rows = _to_java_rows(data, col_names) + return SparkConnectDataFrame( + self._jsession.createDataFrame(java_rows, java_schema)) + + def range(self, start, end=None, step=1, numPartitions=None): + if end is None: + end = start + start = 0 + if numPartitions: + return SparkConnectDataFrame( + self._jsession.range(start, end, step, numPartitions)) + return SparkConnectDataFrame(self._jsession.range(start, end, step)) + + @property + def catalog(self): + return self._jsession.catalog() + + @property + def version(self): + return self._jsession.version() + + @property + def conf(self): + return self._jsession.conf() + + def stop(self): + pass + + def __repr__(self): + return "SparkConnectSession (via Py4j)" + + def __getattr__(self, name): + return getattr(self._jsession, name) + + +def pip_install(*packages): + """Install Python packages into the interpreter pod's environment. 
+ + Usage: + pip_install("requests") + pip_install("requests", "pandas", "numpy==1.24.0") + pip_install("requests>=2.28,<3.0") + """ + import subprocess + import importlib + import site + if not packages: + print("Usage: pip_install('package1', 'package2', ...)") + return + cmd = [sys.executable, "-m", "pip", "install", "--quiet"] + list(packages) + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + if result.returncode == 0: + installed = ", ".join(packages) + print("Successfully installed: %s" % installed) + if result.stdout.strip(): + print(result.stdout.strip()) + importlib.invalidate_caches() + for new_path in site.getsitepackages() + [site.getusersitepackages()]: + if new_path not in sys.path: + sys.path.insert(0, new_path) + else: + print("pip install failed (exit code %d):" % result.returncode) + if result.stderr.strip(): + print(result.stderr.strip()) + if result.stdout.strip(): + print(result.stdout.strip()) + except subprocess.TimeoutExpired: + print("pip install timed out after 300 seconds") + except Exception as e: + print("pip install error: %s" % str(e)) + + +def display(obj, n=20): + """Databricks-compatible display function. + + For SparkConnectDataFrame, renders as Zeppelin %table format. + For other objects, falls back to print(). + """ + if isinstance(obj, SparkConnectDataFrame): + obj.show(n) + else: + print(obj) + + +spark = SparkConnectSession(_jspark) +sqlContext = sqlc = spark diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java new file mode 100644 index 00000000000..a8fe2cc4c52 --- /dev/null +++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * Integration tests for PySparkConnectInterpreter. + * Requires a running Spark Connect server. + * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable. 
+ */ +@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+") +public class PySparkConnectInterpreterTest { + + private static PySparkConnectInterpreter interpreter; + private static SparkConnectInterpreter sparkConnectInterpreter; + private static InterpreterGroup intpGroup; + + @BeforeAll + public static void setUp() throws Exception { + String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE"); + Properties p = new Properties(); + p.setProperty("spark.remote", remote); + p.setProperty("spark.app.name", "ZeppelinPySparkConnectTest"); + p.setProperty("zeppelin.spark.maxResult", "100"); + p.setProperty("zeppelin.pyspark.connect.useIPython", "false"); + p.setProperty("zeppelin.python", "python"); + + intpGroup = new InterpreterGroup(); + + // Create SparkConnectInterpreter first (required dependency) + sparkConnectInterpreter = new SparkConnectInterpreter(p); + sparkConnectInterpreter.setInterpreterGroup(intpGroup); + intpGroup.put("session_1", new LinkedList()); + intpGroup.get("session_1").add(sparkConnectInterpreter); + sparkConnectInterpreter.open(); + + // Create PySparkConnectInterpreter + interpreter = new PySparkConnectInterpreter(p); + interpreter.setInterpreterGroup(intpGroup); + intpGroup.get("session_1").add(interpreter); + interpreter.open(); + } + + @AfterAll + public static void tearDown() throws InterpreterException { + if (interpreter != null) { + interpreter.close(); + } + if (sparkConnectInterpreter != null) { + sparkConnectInterpreter.close(); + } + } + + private static InterpreterContext getInterpreterContext() { + return InterpreterContext.builder() + .setNoteId("noteId") + .setParagraphId("paragraphId") + .setParagraphTitle("title") + .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null)) + .setResourcePool(new LocalResourcePool("id")) + .setInterpreterOut(new InterpreterOutput()) + .setIntpEventClient(mock(RemoteInterpreterEventClient.class)) + .build(); + } + + @Test + void 
testSparkSessionCreated() { + assertNotNull(interpreter.getSparkConnectInterpreter()); + assertNotNull(interpreter.getSparkConnectInterpreter().getSparkSession()); + } + + @Test + void testSimpleQuery() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "df = spark.sql(\"SELECT 1 AS id, 'hello' AS message\")\ndf.show()", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("id") || output.contains("message") || output.contains("hello"), + "Output should contain query results: " + output); + } + + @Test + void testDataFrameVariable() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "df = spark.sql(\"SELECT 1 AS id, 'test' AS name\")\nprint(type(df))", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("DataFrame") || output.contains("pyspark"), + "Output should indicate DataFrame type: " + output); + } + + @Test + void testDeltaTableQuery() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + // Test the exact query from user + InterpreterResult result = interpreter.interpret( + "df = spark.sql(\"select * from gold.delta_test\")", context); + // This might fail if table doesn't exist, but should not crash + // We check that interpreter handled it gracefully + assertTrue(result.code() == InterpreterResult.Code.SUCCESS + || result.code() == InterpreterResult.Code.ERROR, + "Should handle query execution (success or error): " + result.code()); + } + + @Test + void testSparkVariableAvailable() throws InterpreterException, IOException { + InterpreterContext context = 
getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "print('Spark session:', type(spark))", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("SparkSession") || output.contains("spark"), + "Spark session should be available: " + output); + } +} diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java new file mode 100644 index 00000000000..0b3cbc8213b --- /dev/null +++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * Integration tests for SparkConnectInterpreter. + * Requires a running Spark Connect server. + * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable. 
+ */ +@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+") +public class SparkConnectInterpreterTest { + + private static SparkConnectInterpreter interpreter; + private static InterpreterGroup intpGroup; + + @BeforeAll + public static void setUp() throws Exception { + String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE"); + Properties p = new Properties(); + p.setProperty("spark.remote", remote); + p.setProperty("spark.app.name", "ZeppelinSparkConnectTest"); + p.setProperty("zeppelin.spark.maxResult", "100"); + p.setProperty("zeppelin.spark.sql.stacktrace", "true"); + + intpGroup = new InterpreterGroup(); + interpreter = new SparkConnectInterpreter(p); + interpreter.setInterpreterGroup(intpGroup); + intpGroup.put("session_1", new LinkedList()); + intpGroup.get("session_1").add(interpreter); + + interpreter.open(); + } + + @AfterAll + public static void tearDown() throws InterpreterException { + if (interpreter != null) { + interpreter.close(); + } + } + + private static InterpreterContext getInterpreterContext() { + return InterpreterContext.builder() + .setNoteId("noteId") + .setParagraphId("paragraphId") + .setParagraphTitle("title") + .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null)) + .setResourcePool(new LocalResourcePool("id")) + .setInterpreterOut(new InterpreterOutput()) + .setIntpEventClient(mock(RemoteInterpreterEventClient.class)) + .build(); + } + + @Test + void testSparkSessionCreated() { + assertNotNull(interpreter.getSparkSession()); + } + + @Test + void testSimpleQuery() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "SELECT 1 AS id, 'hello' AS message", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("id")); + assertTrue(output.contains("message")); + 
assertTrue(output.contains("hello")); + } + + @Test + void testMultipleStatements() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "SELECT 1 AS a; SELECT 2 AS b", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + } + + @Test + void testInvalidSQL() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "SELECT FROM WHERE INVALID", context); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + } + + @Test + void testMaxResultLimit() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + context.getLocalProperties().put("limit", "5"); + InterpreterResult result = interpreter.interpret( + "SELECT id FROM range(100)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + String[] lines = output.split("\n"); + // header + 5 data rows + assertTrue(lines.length <= 7); + } + + @Test + void testDDL() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = interpreter.interpret( + "CREATE OR REPLACE TEMP VIEW test_view AS SELECT 1 AS id, 'test' AS name", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + context = getInterpreterContext(); + result = interpreter.interpret("SELECT * FROM test_view", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("test")); + } + + @Test + void testFormType() { + assertEquals(Interpreter.FormType.SIMPLE, interpreter.getFormType()); + } + + @Test + void testProgress() throws InterpreterException { + InterpreterContext context = 
getInterpreterContext(); + assertEquals(0, interpreter.getProgress(context)); + } +} diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java new file mode 100644 index 00000000000..acee3ed7a44 --- /dev/null +++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import org.apache.zeppelin.display.AngularObjectRegistry; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; +import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient; +import org.apache.zeppelin.resource.LocalResourcePool; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.io.IOException; +import java.util.LinkedList; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +/** + * Integration tests for SparkConnectSqlInterpreter. + * Requires a running Spark Connect server. + * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable. 
+ */ +@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+") +public class SparkConnectSqlInterpreterTest { + + private static SparkConnectInterpreter connectInterpreter; + private static SparkConnectSqlInterpreter sqlInterpreter; + private static InterpreterGroup intpGroup; + + @BeforeAll + public static void setUp() throws Exception { + String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE"); + Properties p = new Properties(); + p.setProperty("spark.remote", remote); + p.setProperty("spark.app.name", "ZeppelinSparkConnectSqlTest"); + p.setProperty("zeppelin.spark.maxResult", "100"); + p.setProperty("zeppelin.spark.concurrentSQL", "true"); + p.setProperty("zeppelin.spark.concurrentSQL.max", "10"); + p.setProperty("zeppelin.spark.sql.stacktrace", "true"); + p.setProperty("zeppelin.spark.sql.interpolation", "false"); + + intpGroup = new InterpreterGroup(); + connectInterpreter = new SparkConnectInterpreter(p); + connectInterpreter.setInterpreterGroup(intpGroup); + + sqlInterpreter = new SparkConnectSqlInterpreter(p); + sqlInterpreter.setInterpreterGroup(intpGroup); + + intpGroup.put("session_1", new LinkedList()); + intpGroup.get("session_1").add(connectInterpreter); + intpGroup.get("session_1").add(sqlInterpreter); + + connectInterpreter.open(); + sqlInterpreter.open(); + } + + @AfterAll + public static void tearDown() throws InterpreterException { + if (sqlInterpreter != null) { + sqlInterpreter.close(); + } + if (connectInterpreter != null) { + connectInterpreter.close(); + } + } + + private static InterpreterContext getInterpreterContext() { + return InterpreterContext.builder() + .setNoteId("noteId") + .setParagraphId("paragraphId") + .setParagraphTitle("title") + .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null)) + .setResourcePool(new LocalResourcePool("id")) + .setInterpreterOut(new InterpreterOutput()) + .setIntpEventClient(mock(RemoteInterpreterEventClient.class)) + .build(); + } + + @Test + void 
testSimpleQuery() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = sqlInterpreter.interpret( + "SELECT 1 AS id, 'hello' AS message", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("id")); + assertTrue(output.contains("message")); + assertTrue(output.contains("hello")); + } + + @Test + void testMultipleStatements() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = sqlInterpreter.interpret( + "SELECT 1 AS a; SELECT 2 AS b", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + assertEquals(2, context.out.toInterpreterResultMessage().size()); + } + + @Test + void testInvalidSQL() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = sqlInterpreter.interpret( + "SELECT FROM WHERE INVALID", context); + assertEquals(InterpreterResult.Code.ERROR, result.code()); + assertTrue(context.out.toInterpreterResultMessage().get(0).getData().length() > 0); + } + + @Test + void testMaxResultLimit() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + context.getLocalProperties().put("limit", "3"); + InterpreterResult result = sqlInterpreter.interpret( + "SELECT id FROM range(100)", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + String[] lines = output.split("\n"); + // header + 3 data rows + assertTrue(lines.length <= 5); + } + + @Test + void testCreateAndQuery() throws InterpreterException, IOException { + InterpreterContext context = getInterpreterContext(); + InterpreterResult result = sqlInterpreter.interpret( + "CREATE OR REPLACE TEMP VIEW sql_test AS 
SELECT 42 AS answer", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + + context = getInterpreterContext(); + result = sqlInterpreter.interpret("SELECT * FROM sql_test", context); + assertEquals(InterpreterResult.Code.SUCCESS, result.code()); + String output = context.out.toInterpreterResultMessage().get(0).getData(); + assertTrue(output.contains("42")); + } + + @Test + void testFormType() { + assertEquals(Interpreter.FormType.SIMPLE, sqlInterpreter.getFormType()); + } +} diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java new file mode 100644 index 00000000000..ebafd1a9bef --- /dev/null +++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zeppelin.spark; + +import org.junit.jupiter.api.Test; + +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SparkConnectUtilsTest { + + @Test + void testBuildConnectionStringDefault() { + Properties props = new Properties(); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://localhost:15002", result); + } + + @Test + void testBuildConnectionStringCustomRemote() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://spark-server.example.com:15002"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://spark-server.example.com:15002", result); + } + + @Test + void testBuildConnectionStringWithToken() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://spark-server.example.com:15002"); + props.setProperty("spark.connect.token", "my-secret-token"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://spark-server.example.com:15002/;token=my-secret-token", result); + } + + @Test + void testBuildConnectionStringWithTokenAndExistingParams() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://host:15002/;use_ssl=true"); + props.setProperty("spark.connect.token", "tok123"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://host:15002/;use_ssl=true;token=tok123", result); + } + + @Test + void testBuildConnectionStringEmptyToken() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://host:15002"); + props.setProperty("spark.connect.token", ""); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://host:15002", result); + } + + @Test + void testBuildConnectionStringWithUseSsl() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://ranking-cluster-m:8080"); + 
props.setProperty("spark.connect.use_ssl", "true"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://ranking-cluster-m:8080/;use_ssl=true", result); + } + + @Test + void testBuildConnectionStringWithSslAndToken() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://cluster:8080"); + props.setProperty("spark.connect.use_ssl", "true"); + props.setProperty("spark.connect.token", "abc123"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://cluster:8080/;token=abc123;use_ssl=true", result); + } + + @Test + void testBuildConnectionStringDataprocTunnel() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://localhost:15002"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://localhost:15002", result); + } + + @Test + void testBuildConnectionStringDataprocDirect() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://ranking-cluster-m:8080"); + String result = SparkConnectUtils.buildConnectionString(props); + assertEquals("sc://ranking-cluster-m:8080", result); + } + + @Test + void testBuildConnectionStringWithUserName() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://cluster:8080"); + String result = SparkConnectUtils.buildConnectionString(props, "alice"); + assertEquals("sc://cluster:8080/;user_id=alice", result); + } + + @Test + void testBuildConnectionStringWithUserNameAndToken() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://cluster:8080"); + props.setProperty("spark.connect.token", "tok"); + String result = SparkConnectUtils.buildConnectionString(props, "bob"); + assertEquals("sc://cluster:8080/;token=tok;user_id=bob", result); + } + + @Test + void testBuildConnectionStringUserIdAlreadyInUrl() { + Properties props = new Properties(); + props.setProperty("spark.remote", 
"sc://cluster:8080/;user_id=preexisting"); + String result = SparkConnectUtils.buildConnectionString(props, "alice"); + assertEquals("sc://cluster:8080/;user_id=preexisting", result); + } + + @Test + void testBuildConnectionStringNullUserName() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://cluster:8080"); + String result = SparkConnectUtils.buildConnectionString(props, null); + assertEquals("sc://cluster:8080", result); + } + + @Test + void testBuildConnectionStringBlankUserName() { + Properties props = new Properties(); + props.setProperty("spark.remote", "sc://cluster:8080"); + String result = SparkConnectUtils.buildConnectionString(props, " "); + assertEquals("sc://cluster:8080", result); + } + + @Test + void testReplaceReservedChars() { + assertEquals("hello world", SparkConnectUtils.replaceReservedChars("hello\tworld")); + assertEquals("hello world", SparkConnectUtils.replaceReservedChars("hello\nworld")); + assertEquals("null", SparkConnectUtils.replaceReservedChars(null)); + assertEquals("normal", SparkConnectUtils.replaceReservedChars("normal")); + assertEquals("a b c", SparkConnectUtils.replaceReservedChars("a\tb\nc")); + } +}