diff --git a/pom.xml b/pom.xml
index 772df8bc6da..65a529a5c24 100644
--- a/pom.xml
+++ b/pom.xml
@@ -62,6 +62,7 @@
zeppelin-jupyter-interpreter-shaded
groovy
spark
+ spark-connect
spark-submit
submarine
markdown
diff --git a/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml b/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml
index 09ed9a39013..34e658fc2c8 100644
--- a/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml
+++ b/scripts/docker/zeppelin-interpreter/env_python_3_with_R.yml
@@ -4,29 +4,55 @@ channels:
- defaults
dependencies:
- python >=3.9,<3.10
- - pyspark=3.3.2
+ - pyspark=3.5
- pycodestyle
- - scipy
+ # --- Core data libraries ---
+ - pandas
- numpy
+ - scipy
+ - pyarrow
+ # --- Spark Connect protocol ---
- grpcio
- protobuf
+ # --- HTTP / networking ---
+ - requests
+ - urllib3
+ # --- File format support ---
+ - openpyxl
+ - xlrd
+ - pyyaml
+ - tabulate
+ # --- GCP access ---
+ - google-cloud-storage
+ - google-auth
+ - gcsfs
+ # --- Visualization ---
+ - matplotlib
+ - seaborn
+ - plotly
+ - plotnine
+ - altair
+ - vega_datasets
+ - hvplot
+ # --- SQL on pandas ---
- pandasql
+ # --- ML ---
+ - scikit-learn
+ - xgboost
+ # --- IPython / kernel ---
- ipython
- ipykernel
- jupyter_client
- - hvplot
- - plotnine
- - seaborn
+ # --- Data connectors ---
- intake
- intake-parquet
- intake-xarray
- - altair
- - vega_datasets
- - plotly
+ # --- pip-only packages ---
- pip
- pip:
- # works for regular pip packages
- bkzep==0.6.1
+ - delta-spark==3.2.1
+ # --- R support ---
- r-base=3
- r-data.table
- r-evaluate
diff --git a/spark-connect/pom.xml b/spark-connect/pom.xml
new file mode 100644
index 00000000000..cb2a5bd9dc3
--- /dev/null
+++ b/spark-connect/pom.xml
@@ -0,0 +1,130 @@
+
+
+
+
+ 4.0.0
+
+
+ zeppelin-interpreter-parent
+ org.apache.zeppelin
+ 0.11.2
+ ../zeppelin-interpreter-parent/pom.xml
+
+
+ spark-connect-interpreter
+ jar
+ Zeppelin: Spark Connect Interpreter
+ Zeppelin Spark Connect support via gRPC client
+
+
+ spark-connect
+ 3.5.3
+ 2.12
+
+
+
+
+ org.apache.spark
+ spark-connect-client-jvm_${spark.scala.binary.version}
+ ${spark.connect.version}
+
+
+
+ org.apache.zeppelin
+ zeppelin-python
+ ${project.version}
+
+
+
+ org.apache.commons
+ commons-lang3
+
+
+
+ org.mockito
+ mockito-core
+ test
+
+
+
+
+
+
+ maven-resources-plugin
+
+
+ maven-shade-plugin
+
+
+
+ *:*
+
+ META-INF/*.SF
+ META-INF/*.DSA
+ META-INF/*.RSA
+
+
+
+
+
+
+ reference.conf
+
+
+
+
+ org.apache.zeppelin:zeppelin-interpreter-shaded
+
+
+
+
+ io.netty
+ org.apache.zeppelin.spark.connect.io.netty
+
+
+ com.google
+ org.apache.zeppelin.spark.connect.com.google
+
+
+ io.grpc
+ org.apache.zeppelin.spark.connect.io.grpc
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-checkstyle-plugin
+
+
+
+
+
+
+ spark-connect-3.5
+
+ true
+
+
+ 3.5.3
+
+
+
+
+
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java
new file mode 100644
index 00000000000..d6adef101e6
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/IPySparkConnectInterpreter.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.python.IPythonInterpreter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Properties;
+
+/**
+ * PySpark Connect Interpreter which uses IPython underlying.
+ * Reuses the Java SparkSession from SparkConnectInterpreter via Py4j.
+ */
+public class IPySparkConnectInterpreter extends IPythonInterpreter {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(IPySparkConnectInterpreter.class);
+
+ private SparkConnectInterpreter sparkConnectInterpreter;
+ private PySparkConnectInterpreter pySparkConnectInterpreter;
+ private boolean opened = false;
+ private InterpreterContext curIntpContext;
+
+ public IPySparkConnectInterpreter(Properties property) {
+ super(property);
+ }
+
+ @Override
+ public synchronized void open() throws InterpreterException {
+ if (opened) {
+ return;
+ }
+
+ this.sparkConnectInterpreter =
+ getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class);
+ this.pySparkConnectInterpreter =
+ getInterpreterInTheSameSessionByClassName(PySparkConnectInterpreter.class, false);
+
+ sparkConnectInterpreter.open();
+
+ setProperty("zeppelin.python", pySparkConnectInterpreter.getPythonExec());
+ setUseBuiltinPy4j(true);
+ setAdditionalPythonInitFile("python/zeppelin_isparkconnect.py");
+ super.open();
+ opened = true;
+ }
+
+ @Override
+ public org.apache.zeppelin.interpreter.InterpreterResult interpret(String st,
+ InterpreterContext context) throws InterpreterException {
+ InterpreterContext.set(context);
+ this.curIntpContext = context;
+ String setInptContextStmt = "intp.setInterpreterContextInPython()";
+ org.apache.zeppelin.interpreter.InterpreterResult result =
+ super.interpret(setInptContextStmt, context);
+ if (result.code().equals(org.apache.zeppelin.interpreter.InterpreterResult.Code.ERROR)) {
+ return new org.apache.zeppelin.interpreter.InterpreterResult(
+ org.apache.zeppelin.interpreter.InterpreterResult.Code.ERROR,
+ "Fail to setCurIntpContext");
+ }
+
+ return super.interpret(st, context);
+ }
+
+ public void setInterpreterContextInPython() {
+ InterpreterContext.set(curIntpContext);
+ }
+
+ public SparkSession getSparkSession() {
+ if (sparkConnectInterpreter != null) {
+ return sparkConnectInterpreter.getSparkSession();
+ }
+ return null;
+ }
+
+ @Override
+ public void cancel(InterpreterContext context) throws InterpreterException {
+ super.cancel(context);
+ if (sparkConnectInterpreter != null) {
+ sparkConnectInterpreter.cancel(context);
+ }
+ }
+
+ @Override
+ public void close() throws InterpreterException {
+ LOGGER.info("Close IPySparkConnectInterpreter (opened={})", opened);
+ try {
+ super.close();
+ } finally {
+ opened = false;
+ sparkConnectInterpreter = null;
+ pySparkConnectInterpreter = null;
+ LOGGER.info("IPySparkConnectInterpreter closed and state reset — ready for re-open");
+ }
+ }
+
+ @Override
+ public int getProgress(InterpreterContext context) throws InterpreterException {
+ return 0;
+ }
+
+ public int getMaxResult() {
+ if (sparkConnectInterpreter != null) {
+ return sparkConnectInterpreter.getMaxResult();
+ }
+ return 1000;
+ }
+
+ @SuppressWarnings("unchecked")
+ public String formatDataFrame(Object df, int maxResult) {
+ return SparkConnectUtils.showDataFrame((Dataset) df, maxResult);
+ }
+}
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/NotebookLockManager.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/NotebookLockManager.java
new file mode 100644
index 00000000000..8f39e937cc9
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/NotebookLockManager.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * Shared utility class for managing notebook-level locks.
+ * Ensures that only one query executes at a time per notebook,
+ * regardless of which interpreter (SparkConnectInterpreter or SparkConnectSqlInterpreter) is used.
+ */
public final class NotebookLockManager {

  /** noteId -> fair lock serializing query execution within that notebook. */
  private static final ConcurrentHashMap<String, ReentrantLock> notebookLocks =
      new ConcurrentHashMap<>();

  private NotebookLockManager() {
    // Static utility class; not instantiable.
  }

  /**
   * Get or create a lock for the specified notebook.
   * Uses fair locking to ensure FIFO ordering of query execution.
   *
   * @param noteId The notebook ID
   * @return The lock for this notebook
   */
  public static ReentrantLock getNotebookLock(String noteId) {
    return notebookLocks.computeIfAbsent(noteId,
        k -> new ReentrantLock(true)); // Fair lock for FIFO ordering
  }

  /**
   * Remove the lock for a notebook (cleanup when notebook is closed).
   * Callers must ensure no thread still holds or is waiting on the lock,
   * otherwise a later getNotebookLock creates a fresh lock and mutual
   * exclusion with the old holder is lost.
   *
   * @param noteId The notebook ID
   */
  public static void removeNotebookLock(String noteId) {
    notebookLocks.remove(noteId);
  }

  /**
   * Get the number of active notebook locks (for monitoring/debugging).
   *
   * @return The number of active locks
   */
  public static int getActiveLockCount() {
    return notebookLocks.size();
  }
}
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java
new file mode 100644
index 00000000000..1b79a4ac164
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/PySparkConnectInterpreter.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.python.PythonInterpreter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Map;
+import java.util.Properties;
+
+/**
+ * PySpark interpreter for Spark Connect.
+ * Reuses the Java SparkSession from SparkConnectInterpreter (via Py4j)
+ * so that the Python side uses the same 3.5.x-compatible client as the SQL interpreter.
+ */
+public class PySparkConnectInterpreter extends PythonInterpreter {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(PySparkConnectInterpreter.class);
+
+ private SparkConnectInterpreter sparkConnectInterpreter;
+ private InterpreterContext curIntpContext;
+
+ public PySparkConnectInterpreter(Properties property) {
+ super(property);
+ this.useBuiltinPy4j = true;
+ }
+
+ @Override
+ public void open() throws InterpreterException {
+ setProperty("zeppelin.python.useIPython",
+ getProperty("zeppelin.pyspark.connect.useIPython", "true"));
+
+ this.sparkConnectInterpreter =
+ getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class);
+
+ // Ensure the Java SparkSession is ready before starting Python
+ sparkConnectInterpreter.open();
+
+ // Log Python executable resolution (matching Spark's behavior)
+ String pythonExec = getPythonExec();
+ LOGGER.info("Python executable resolved: {}", pythonExec);
+
+ // Call super.open() - let PythonInterpreter handle Python process launch
+ // This matches Spark's PySparkInterpreter behavior - no pre-validation
+ super.open();
+
+ if (!useIPython()) {
+ try {
+ bootstrapInterpreter("python/zeppelin_sparkconnect.py");
+ } catch (IOException e) {
+ LOGGER.error("Fail to bootstrap spark connect", e);
+ throw new InterpreterException("Fail to bootstrap spark connect", e);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws InterpreterException {
+ LOGGER.info("Close PySparkConnectInterpreter");
+ super.close();
+ }
+
+ @Override
+ protected org.apache.zeppelin.python.IPythonInterpreter getIPythonInterpreter()
+ throws InterpreterException {
+ return getInterpreterInTheSameSessionByClassName(IPySparkConnectInterpreter.class, false);
+ }
+
+ @Override
+ public org.apache.zeppelin.interpreter.InterpreterResult interpret(String st,
+ InterpreterContext context) throws InterpreterException {
+ curIntpContext = context;
+ return super.interpret(st, context);
+ }
+
+ @Override
+ protected void preCallPython(InterpreterContext context) {
+ callPython(new PythonInterpretRequest(
+ "intp.setInterpreterContextInPython()", false, false));
+ }
+
+ public void setInterpreterContextInPython() {
+ InterpreterContext.set(curIntpContext);
+ }
+
+ @Override
+ protected Map setupPythonEnv() throws IOException {
+ Map env = super.setupPythonEnv();
+
+ // Set PYSPARK_PYTHON environment variable (following Spark's pattern)
+ // This ensures Python subprocesses can find the correct Python executable
+ String pythonExec = getPythonExec();
+ env.put("PYSPARK_PYTHON", pythonExec);
+ LOGGER.info("Set PYSPARK_PYTHON: {}", pythonExec);
+
+ // Set up LD_LIBRARY_PATH for conda installations
+ // This is critical - conda Python binaries depend on libraries in conda/lib
+ setupCondaLibraryPath(env, pythonExec);
+
+ LOGGER.info("LD_LIBRARY_PATH: {}", env.get("LD_LIBRARY_PATH"));
+ return env;
+ }
+
+ /**
+ * Get Python executable following Spark's PySpark detection pattern exactly:
+ * 1. spark.pyspark.driver.python (from Spark Connect properties)
+ * 2. spark.pyspark.python (from Spark Connect properties)
+ * 3. PYSPARK_DRIVER_PYTHON (environment variable)
+ * 4. PYSPARK_PYTHON (environment variable)
+ * 5. zeppelin.python (Zeppelin property) - if set, validate it
+ * 6. Default to "python" (let system PATH handle it, just like Spark does)
+ *
+ * This matches Spark's PySparkInterpreter.getPythonExec() behavior.
+ * Spark defaults to "python" and relies on system PATH - we do the same.
+ */
+ @Override
+ protected String getPythonExec() {
+ // Priority 1: spark.pyspark.driver.python (Spark Connect property)
+ String driverPython = getProperty("spark.pyspark.driver.python", "");
+ if (StringUtils.isNotBlank(driverPython)) {
+ LOGGER.info("Using Python executable from spark.pyspark.driver.python: {}", driverPython);
+ // Don't validate here - let ProcessBuilder fail naturally if invalid
+ // This matches Spark's behavior - it trusts the configuration
+ return driverPython;
+ }
+
+ // Priority 2: spark.pyspark.python (Spark Connect property)
+ String pysparkPython = getProperty("spark.pyspark.python", "");
+ if (StringUtils.isNotBlank(pysparkPython)) {
+ LOGGER.info("Using Python executable from spark.pyspark.python: {}", pysparkPython);
+ return pysparkPython;
+ }
+
+ // Priority 3: PYSPARK_DRIVER_PYTHON (environment variable)
+ String envDriverPython = System.getenv("PYSPARK_DRIVER_PYTHON");
+ if (StringUtils.isNotBlank(envDriverPython)) {
+ LOGGER.info("Using Python executable from PYSPARK_DRIVER_PYTHON: {}", envDriverPython);
+ return envDriverPython;
+ }
+
+ // Priority 4: PYSPARK_PYTHON (environment variable)
+ String envPysparkPython = System.getenv("PYSPARK_PYTHON");
+ if (StringUtils.isNotBlank(envPysparkPython)) {
+ LOGGER.info("Using Python executable from PYSPARK_PYTHON: {}", envPysparkPython);
+ return envPysparkPython;
+ }
+
+ // Priority 5: zeppelin.python (Zeppelin property) - only if explicitly set
+ String zeppelinPython = getProperty("zeppelin.python", "");
+ if (StringUtils.isNotBlank(zeppelinPython)) {
+ LOGGER.info("Using Python executable from zeppelin.python property: {}", zeppelinPython);
+ return zeppelinPython;
+ }
+
+ // Priority 6: Default to "python" (let system PATH handle it, just like Spark)
+ // Spark's PySparkInterpreter defaults to "python" - we do the same
+ // This relies on system PATH to find Python, no explicit path needed
+ LOGGER.info("No Python executable configured, defaulting to 'python' (will use system PATH)");
+ return "python";
+ }
+
+ private void setupCondaLibraryPath(Map env, String pythonExec) {
+ // If python path contains "/conda/", add conda lib to LD_LIBRARY_PATH
+ // This only applies if an explicit conda path is configured
+ if (pythonExec != null && pythonExec.contains("/conda/")) {
+ // Extract conda base path (e.g., /opt/conda/default from /opt/conda/default/bin/python3)
+ int binIndex = pythonExec.indexOf("/bin/");
+ if (binIndex > 0) {
+ String condaBase = pythonExec.substring(0, binIndex);
+ String condaLib = condaBase + "/lib";
+ java.io.File libDir = new java.io.File(condaLib);
+ if (libDir.exists() && libDir.isDirectory()) {
+ String ldLibraryPath = env.getOrDefault("LD_LIBRARY_PATH", "");
+ if (ldLibraryPath.isEmpty()) {
+ env.put("LD_LIBRARY_PATH", condaLib);
+ } else if (!ldLibraryPath.contains(condaLib)) {
+ env.put("LD_LIBRARY_PATH", condaLib + ":" + ldLibraryPath);
+ }
+ LOGGER.info("Added conda lib directory to LD_LIBRARY_PATH: {}", condaLib);
+ }
+ }
+ }
+ // If using "python" from PATH, don't modify LD_LIBRARY_PATH
+ // Let the system handle it - Python should already be configured correctly
+ }
+
+ /**
+ * Exposes the Java SparkSession to the Python process via Py4j gateway.
+ * This is the same session used by SparkConnectSqlInterpreter.
+ */
+ public SparkSession getSparkSession() {
+ if (sparkConnectInterpreter != null) {
+ return sparkConnectInterpreter.getSparkSession();
+ }
+ return null;
+ }
+
+ public SparkConnectInterpreter getSparkConnectInterpreter() {
+ return sparkConnectInterpreter;
+ }
+
+ public int getMaxResult() {
+ if (sparkConnectInterpreter != null) {
+ return sparkConnectInterpreter.getMaxResult();
+ }
+ return Integer.parseInt(getProperty("zeppelin.spark.maxResult", "1000"));
+ }
+
+ @SuppressWarnings("unchecked")
+ public String formatDataFrame(Object df, int maxResult) {
+ return SparkConnectUtils.showDataFrame((Dataset) df, maxResult);
+ }
+}
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectInterpreter.java
new file mode 100644
index 00000000000..b8ae477a0e8
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectInterpreter.java
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.zeppelin.interpreter.AbstractInterpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.InterpreterResult.Code;
+import org.apache.zeppelin.interpreter.ZeppelinContext;
+import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.interpreter.util.SqlSplitter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * Spark Connect interpreter for Zeppelin.
+ * Connects to a remote Spark cluster via the Spark Connect gRPC protocol.
+ *
+ * Session & Concurrency Model:
+ *
+ * - Max Spark Connect sessions per user is capped
+ * ({@code zeppelin.spark.connect.maxSessionsPerUser}, default 5).
+ * Each interpreter instance creates one session; Zeppelin's binding mode
+ * (per-user / per-note / scoped / isolated) controls how many instances exist.
+ * - Notebooks are unlimited -- users may open as many as they like.
+ * - Within a single notebook only one query executes at a time
+ * (per-notebook fair lock via {@link NotebookLockManager}).
+ *
+ */
+public class SparkConnectInterpreter extends AbstractInterpreter {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(SparkConnectInterpreter.class);
+
+ /** user -> number of active SparkSession instances owned by that user. */
+ private static final ConcurrentHashMap userSessionCount =
+ new ConcurrentHashMap<>();
+
+ private static final int DEFAULT_MAX_SESSIONS_PER_USER = 5;
+
+ private SparkSession sparkSession;
+ private SqlSplitter sqlSplitter;
+ private int maxResult;
+ private String currentUser;
+ private volatile boolean sessionSlotAcquired = false;
+
+ public SparkConnectInterpreter(Properties properties) {
+ super(properties);
+ }
+
+ @Override
+ public synchronized void open() throws InterpreterException {
+ if (sparkSession != null) {
+ LOGGER.warn("open() called but sparkSession is already active — skipping. " +
+ "Call close() first to restart the interpreter.");
+ return;
+ }
+
+ try {
+ currentUser = getUserName();
+ if (StringUtils.isBlank(currentUser)) {
+ currentUser = "anonymous";
+ }
+
+ int maxSessions = Integer.parseInt(
+ getProperty("zeppelin.spark.connect.maxSessionsPerUser",
+ String.valueOf(DEFAULT_MAX_SESSIONS_PER_USER)));
+
+ if (!acquireSessionSlot(currentUser, maxSessions)) {
+ throw new InterpreterException(
+ String.format("User '%s' already has %d active Spark Connect sessions " +
+ "(max %d). Close an existing interpreter before opening a new one.",
+ currentUser, maxSessions, maxSessions));
+ }
+ sessionSlotAcquired = true;
+
+ LOGGER.info("Opening SparkConnectInterpreter for user: {} (session slot {}/{})",
+ currentUser, userSessionCount.getOrDefault(currentUser, 0), maxSessions);
+
+ String remoteUrl = SparkConnectUtils.buildConnectionString(getProperties(), currentUser);
+ LOGGER.info("Connecting to Spark Connect server at: {}",
+ remoteUrl.replaceAll("token=[^;]*", "token=[REDACTED]")
+ .replaceAll("user_id=[^;]*", "user_id=[REDACTED]"));
+
+ // Clear the thread-local active session on the Spark Connect client so that
+ // getOrCreate() creates a fresh remote session rather than reusing the previous
+ // closed one. We intentionally do NOT call clearDefaultSession() because that
+ // is a JVM-global operation and would disrupt other interpreter instances that
+ // are concurrently active in the same process.
+ try {
+ SparkSession.clearActiveSession();
+ LOGGER.info("Cleared thread-local active Spark session (safe for multi-interpreter)");
+ } catch (Exception e) {
+ LOGGER.warn("Could not clear active Spark session (non-fatal): {}", e.getMessage());
+ }
+
+ SparkSession.Builder builder = SparkSession.builder().remote(remoteUrl);
+
+ String appName = getProperty("spark.app.name", "Zeppelin Spark Connect");
+ if (StringUtils.isNotBlank(appName)) {
+ builder.appName(appName);
+ }
+
+ String grpcMaxMsgSize = getProperty(
+ "spark.connect.grpc.maxMessageSize", "134217728");
+ builder.config("spark.connect.grpc.maxMessageSize", grpcMaxMsgSize);
+
+ for (Object key : getProperties().keySet()) {
+ String keyStr = key.toString();
+ String value = getProperties().getProperty(keyStr);
+ if (StringUtils.isNotBlank(value)
+ && keyStr.startsWith("spark.")
+ && !keyStr.equals("spark.remote")
+ && !keyStr.equals("spark.connect.token")
+ && !keyStr.equals("spark.connect.use_ssl")
+ && !keyStr.equals("spark.app.name")
+ && !keyStr.equals("spark.connect.grpc.maxMessageSize")) {
+ builder.config(keyStr, value);
+ }
+ }
+
+ sparkSession = builder.getOrCreate();
+ LOGGER.info("Spark Connect session established for user: {}", currentUser);
+
+ maxResult = Integer.parseInt(getProperty("zeppelin.spark.maxResult", "1000"));
+ sqlSplitter = new SqlSplitter();
+ } catch (InterpreterException ie) {
+ throw ie;
+ } catch (Exception e) {
+ if (sessionSlotAcquired) {
+ releaseSessionSlot(currentUser);
+ sessionSlotAcquired = false;
+ }
+ LOGGER.error("Failed to connect to Spark Connect server", e);
+ throw new InterpreterException("Failed to connect to Spark Connect server: "
+ + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public void close() throws InterpreterException {
+ LOGGER.info("Closing SparkConnectInterpreter for user: {} (sparkSession={})",
+ currentUser, sparkSession != null ? "active" : "null");
+ if (sparkSession != null) {
+ try {
+ sparkSession.close();
+ LOGGER.info("Spark Connect session closed for user: {}", currentUser);
+ } catch (Exception e) {
+ LOGGER.warn("Error closing Spark Connect session", e);
+ } finally {
+ sparkSession = null;
+ }
+ } else {
+ LOGGER.info("close() called but no active sparkSession — nothing to tear down");
+ }
+ if (sessionSlotAcquired) {
+ releaseSessionSlot(currentUser);
+ sessionSlotAcquired = false;
+ }
+ }
+
+ @Override
+ public ZeppelinContext getZeppelinContext() {
+ return null;
+ }
+
+ @Override
+ public InterpreterResult internalInterpret(String st, InterpreterContext context)
+ throws InterpreterException {
+ if (sparkSession == null) {
+ return new InterpreterResult(Code.ERROR,
+ "Spark Connect session is not initialized. Check connection settings.");
+ }
+
+ String noteId = context.getNoteId();
+ if (StringUtils.isBlank(noteId)) {
+ return new InterpreterResult(Code.ERROR,
+ "Note ID is missing from interpreter context.");
+ }
+
+ // Per-notebook lock: only one query at a time inside a notebook
+ ReentrantLock notebookLock = NotebookLockManager.getNotebookLock(noteId);
+ notebookLock.lock();
+ try {
+ List sqls = sqlSplitter.splitSql(st);
+ int limit = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
+ String.valueOf(maxResult)));
+
+ boolean useStreaming = Boolean.parseBoolean(
+ getProperty("zeppelin.spark.connect.streamResults", "false"));
+
+ String curSql = null;
+ try {
+ for (String sql : sqls) {
+ curSql = sql;
+ if (StringUtils.isBlank(sql)) {
+ continue;
+ }
+ Dataset df = sparkSession.sql(sql);
+ if (useStreaming) {
+ SparkConnectUtils.streamDataFrame(df, limit, context.out);
+ } else {
+ String result = SparkConnectUtils.showDataFrame(df, limit);
+ context.out.write(result);
+ }
+ }
+ context.out.flush();
+ } catch (Exception e) {
+ return handleSqlException(e, curSql, context);
+ }
+
+ return new InterpreterResult(Code.SUCCESS);
+ } finally {
+ notebookLock.unlock();
+ }
+ }
+
+ // ---- session-slot helpers (static, shared across all instances) ----
+
+ /**
+ * Try to claim one session slot for the user.
+ * @return true if a slot was available and claimed
+ */
+ private static synchronized boolean acquireSessionSlot(String user, int maxSessions) {
+ int current = userSessionCount.getOrDefault(user, 0);
+ if (current >= maxSessions) {
+ LOGGER.warn("User {} already has {} active Spark Connect sessions (max {})",
+ user, current, maxSessions);
+ return false;
+ }
+ userSessionCount.put(user, current + 1);
+ LOGGER.info("Acquired session slot for user {}. Active sessions: {}/{}",
+ user, current + 1, maxSessions);
+ return true;
+ }
+
+ /**
+ * Release one session slot for the user.
+ */
+ private static synchronized void releaseSessionSlot(String user) {
+ if (user == null) {
+ return;
+ }
+ int current = userSessionCount.getOrDefault(user, 0);
+ if (current <= 1) {
+ userSessionCount.remove(user);
+ } else {
+ userSessionCount.put(user, current - 1);
+ }
+ LOGGER.info("Released session slot for user {}. Remaining sessions: {}",
+ user, Math.max(0, current - 1));
+ }
+
+ /** Visible for testing. */
+ static int getActiveSessionCount(String user) {
+ return userSessionCount.getOrDefault(user, 0);
+ }
+
+ private InterpreterResult handleSqlException(Exception e, String sql,
+ InterpreterContext context) {
+ try {
+ LOGGER.error("Error executing SQL: {}", sql, e);
+ context.out.write("\nError in SQL: " + sql + "\n");
+ if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "true"))) {
+ if (e.getCause() != null) {
+ context.out.write(ExceptionUtils.getStackTrace(e.getCause()));
+ } else {
+ context.out.write(ExceptionUtils.getStackTrace(e));
+ }
+ } else {
+ String msg = e.getCause() != null ? e.getCause().getMessage() : e.getMessage();
+ context.out.write(msg
+ + "\nSet zeppelin.spark.sql.stacktrace = true to see full stacktrace");
+ }
+ context.out.flush();
+ } catch (IOException ex) {
+ LOGGER.error("Failed to write error output", ex);
+ }
+ return new InterpreterResult(Code.ERROR);
+ }
+
+ @Override
+ public void cancel(InterpreterContext context) throws InterpreterException {
+ if (sparkSession != null) {
+ try {
+ sparkSession.interruptAll();
+ } catch (Exception e) {
+ LOGGER.warn("Error interrupting Spark Connect session", e);
+ }
+ }
+ }
+
+ @Override
+ public FormType getFormType() {
+ return FormType.SIMPLE;
+ }
+
+ @Override
+ public int getProgress(InterpreterContext context) throws InterpreterException {
+ return 0;
+ }
+
+ @Override
+ public List completion(String buf, int cursor,
+ InterpreterContext interpreterContext) throws InterpreterException {
+ return new ArrayList<>();
+ }
+
+ public SparkSession getSparkSession() {
+ return sparkSession;
+ }
+
+ public int getMaxResult() {
+ return maxResult;
+ }
+}
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreter.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreter.java
new file mode 100644
index 00000000000..e642455d9aa
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreter.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.zeppelin.interpreter.AbstractInterpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.InterpreterResult.Code;
+import org.apache.zeppelin.interpreter.ZeppelinContext;
+import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.interpreter.util.SqlSplitter;
+import org.apache.zeppelin.scheduler.Scheduler;
+import org.apache.zeppelin.scheduler.SchedulerFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.locks.ReentrantLock;
+
+/**
+ * Spark Connect SQL interpreter for Zeppelin.
+ * Delegates to SparkConnectInterpreter for the SparkSession, providing
+ * dedicated SQL execution with concurrent query support.
+ */
+public class SparkConnectSqlInterpreter extends AbstractInterpreter {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(SparkConnectSqlInterpreter.class);
+
+ private SparkConnectInterpreter sparkConnectInterpreter;
+ private SqlSplitter sqlSplitter;
+
+ public SparkConnectSqlInterpreter(Properties properties) {
+ super(properties);
+ }
+
+ @Override
+ public void open() throws InterpreterException {
+ this.sparkConnectInterpreter =
+ getInterpreterInTheSameSessionByClassName(SparkConnectInterpreter.class);
+ this.sqlSplitter = new SqlSplitter();
+ }
+
+ @Override
+ public void close() throws InterpreterException {
+ sparkConnectInterpreter = null;
+ }
+
+ @Override
+ protected boolean isInterpolate() {
+ return Boolean.parseBoolean(getProperty("zeppelin.spark.sql.interpolation", "false"));
+ }
+
+ @Override
+ public ZeppelinContext getZeppelinContext() {
+ return null;
+ }
+
+ @Override
+ public InterpreterResult internalInterpret(String st, InterpreterContext context)
+ throws InterpreterException {
+ SparkSession sparkSession = sparkConnectInterpreter.getSparkSession();
+ if (sparkSession == null) {
+ return new InterpreterResult(Code.ERROR,
+ "Spark Connect session is not initialized. Check connection settings.");
+ }
+
+ // Get noteId from context for notebook-level synchronization
+ String noteId = context.getNoteId();
+ if (StringUtils.isBlank(noteId)) {
+ return new InterpreterResult(Code.ERROR,
+ "Note ID is missing from interpreter context.");
+ }
+
+ // Get or create lock for this notebook to ensure sequential execution
+ // This ensures one query at a time per notebook, even with concurrentSQL enabled
+ ReentrantLock notebookLock = NotebookLockManager.getNotebookLock(noteId);
+
+ // Acquire lock to ensure only one query executes at a time for this notebook
+ notebookLock.lock();
+ try {
+ List sqls = sqlSplitter.splitSql(st);
+ int maxResult = Integer.parseInt(context.getLocalProperties().getOrDefault("limit",
+ String.valueOf(sparkConnectInterpreter.getMaxResult())));
+
+ boolean useStreaming = Boolean.parseBoolean(
+ getProperty("zeppelin.spark.connect.streamResults", "false"));
+
+ String curSql = null;
+ try {
+ for (String sql : sqls) {
+ curSql = sql;
+ if (StringUtils.isBlank(sql)) {
+ continue;
+ }
+ Dataset df = sparkSession.sql(sql);
+ if (useStreaming) {
+ SparkConnectUtils.streamDataFrame(df, maxResult, context.out);
+ } else {
+ String result = SparkConnectUtils.showDataFrame(df, maxResult);
+ context.out.write(result);
+ }
+ }
+ context.out.flush();
+ } catch (Exception e) {
+ try {
+ LOGGER.error("Error executing SQL: {}", curSql, e);
+ context.out.write("\nError in SQL: " + curSql + "\n");
+ if (Boolean.parseBoolean(getProperty("zeppelin.spark.sql.stacktrace", "true"))) {
+ if (e.getCause() != null) {
+ context.out.write(ExceptionUtils.getStackTrace(e.getCause()));
+ } else {
+ context.out.write(ExceptionUtils.getStackTrace(e));
+ }
+ } else {
+ String msg = e.getCause() != null ? e.getCause().getMessage() : e.getMessage();
+ context.out.write(msg
+ + "\nSet zeppelin.spark.sql.stacktrace = true to see full stacktrace");
+ }
+ context.out.flush();
+ } catch (IOException ex) {
+ LOGGER.error("Failed to write error output", ex);
+ }
+ return new InterpreterResult(Code.ERROR);
+ }
+
+ return new InterpreterResult(Code.SUCCESS);
+ } finally {
+ notebookLock.unlock();
+ }
+ }
+
+ @Override
+ public void cancel(InterpreterContext context) throws InterpreterException {
+ SparkSession sparkSession = sparkConnectInterpreter.getSparkSession();
+ if (sparkSession != null) {
+ try {
+ sparkSession.interruptAll();
+ } catch (Exception e) {
+ LOGGER.warn("Error interrupting Spark Connect session", e);
+ }
+ }
+ }
+
+ @Override
+ public FormType getFormType() {
+ return FormType.SIMPLE;
+ }
+
+ @Override
+ public int getProgress(InterpreterContext context) throws InterpreterException {
+ return 0;
+ }
+
+ @Override
+ public List completion(String buf, int cursor,
+ InterpreterContext interpreterContext) throws InterpreterException {
+ return new ArrayList<>();
+ }
+
+ @Override
+ public Scheduler getScheduler() {
+ if (concurrentSQL()) {
+ int maxConcurrency = Integer.parseInt(
+ getProperty("zeppelin.spark.concurrentSQL.max", "10"));
+ return SchedulerFactory.singleton().createOrGetParallelScheduler(
+ SparkConnectSqlInterpreter.class.getName() + this.hashCode(), maxConcurrency);
+ } else {
+ try {
+ return getInterpreterInTheSameSessionByClassName(
+ SparkConnectInterpreter.class, false).getScheduler();
+ } catch (InterpreterException e) {
+ throw new RuntimeException("Failed to get scheduler", e);
+ }
+ }
+ }
+
+ private boolean concurrentSQL() {
+ return Boolean.parseBoolean(getProperty("zeppelin.spark.concurrentSQL"));
+ }
+}
diff --git a/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java
new file mode 100644
index 00000000000..ffad1fbb789
--- /dev/null
+++ b/spark-connect/src/main/java/org/apache/zeppelin/spark/SparkConnectUtils.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Properties;
+
+public class SparkConnectUtils {
+
+ static final int DEFAULT_TRUNCATE_LENGTH = 256;
+
+ private SparkConnectUtils() {
+ }
+
+ /**
+ * Build the Spark Connect connection string from interpreter properties.
+ * Format: sc://hostname:port[/;param1=val1;param2=val2]
+ *
+ * Supports: token, use_ssl, user_id, and any extra params already in the URI.
+ * Examples:
+ * sc://localhost:15002
+ * sc://localhost:15002/;use_ssl=true;token=abc123;user_id=alice
+ * sc://ranking-cluster-m:8080
+ */
+ public static String buildConnectionString(Properties properties) {
+ return buildConnectionString(properties, null);
+ }
+
+ /**
+ * Build the Spark Connect connection string, including user_id so the Spark Connect
+ * server can attribute the session to the correct user in its own UI.
+ *
+ * @param properties interpreter properties
+ * @param userName the authenticated Zeppelin username; ignored if blank
+ */
+ public static String buildConnectionString(Properties properties, String userName) {
+ String remote = properties.getProperty("spark.remote", "sc://localhost:15002");
+ StringBuilder params = new StringBuilder();
+
+ String token = properties.getProperty("spark.connect.token", "");
+ if (StringUtils.isNotBlank(token)) {
+ params.append(";token=").append(token);
+ }
+
+ boolean useSsl = Boolean.parseBoolean(
+ properties.getProperty("spark.connect.use_ssl", "false"));
+ if (useSsl) {
+ params.append(";use_ssl=true");
+ }
+
+ if (StringUtils.isNotBlank(userName) && !remote.contains("user_id=")) {
+ params.append(";user_id=").append(userName);
+ }
+
+ if (params.length() > 0) {
+ if (remote.contains(";")) {
+ remote = remote + params;
+ } else {
+ remote = remote + "/" + params;
+ }
+ }
+ return remote;
+ }
+
+ /**
+ * Convert a Dataset to Zeppelin's %table format string.
+ * Applies limit before collecting to prevent OOM on the driver.
+ * Truncates cell values to avoid excessively wide output.
+ */
+ public static String showDataFrame(Dataset df, int maxResult) {
+ return showDataFrame(df, maxResult, DEFAULT_TRUNCATE_LENGTH);
+ }
+
+ public static String showDataFrame(Dataset df, int maxResult, int truncateLength) {
+ StructType schema = df.schema();
+ StructField[] fields = schema.fields();
+
+ int effectiveLimit = Math.max(1, Math.min(maxResult, 100_000));
+
+ int estimatedRowSize = Math.max(fields.length * 20, 100);
+ int estimatedTotalBytes = estimatedRowSize * effectiveLimit;
+ StringBuilder sb = new StringBuilder(Math.min(estimatedTotalBytes, 10 * 1024 * 1024));
+ sb.append("%table ");
+
+ for (int i = 0; i < fields.length; i++) {
+ if (i > 0) {
+ sb.append('\t');
+ }
+ sb.append(replaceReservedChars(fields[i].name()));
+ }
+ sb.append('\n');
+
+ List rows = df.limit(effectiveLimit).collectAsList();
+ for (Row row : rows) {
+ for (int i = 0; i < row.length(); i++) {
+ if (i > 0) {
+ sb.append('\t');
+ }
+ Object value = row.get(i);
+ String cellStr = value == null ? "null" : value.toString();
+ if (truncateLength > 0 && cellStr.length() > truncateLength) {
+ cellStr = cellStr.substring(0, truncateLength) + "...";
+ }
+ sb.append(replaceReservedChars(cellStr));
+ }
+ sb.append('\n');
+ }
+
+ return sb.toString();
+ }
+
+ /**
+ * Stream a Dataset as Zeppelin %table format directly to an OutputStream,
+ * avoiding building the entire result in memory.
+ * Preferred for large result sets.
+ */
+ public static void streamDataFrame(Dataset df, int maxResult, OutputStream out)
+ throws IOException {
+ streamDataFrame(df, maxResult, DEFAULT_TRUNCATE_LENGTH, out);
+ }
+
+ public static void streamDataFrame(Dataset df, int maxResult,
+ int truncateLength, OutputStream out) throws IOException {
+ StructType schema = df.schema();
+ StructField[] fields = schema.fields();
+
+ int effectiveLimit = Math.max(1, Math.min(maxResult, 100_000));
+
+ StringBuilder header = new StringBuilder("%table ");
+ for (int i = 0; i < fields.length; i++) {
+ if (i > 0) {
+ header.append('\t');
+ }
+ header.append(replaceReservedChars(fields[i].name()));
+ }
+ header.append('\n');
+ out.write(header.toString().getBytes(StandardCharsets.UTF_8));
+
+ Iterator it = df.limit(effectiveLimit).toLocalIterator();
+ StringBuilder rowBuf = new StringBuilder(256);
+ while (it.hasNext()) {
+ rowBuf.setLength(0);
+ Row row = it.next();
+ for (int i = 0; i < row.length(); i++) {
+ if (i > 0) {
+ rowBuf.append('\t');
+ }
+ Object value = row.get(i);
+ String cellStr = value == null ? "null" : value.toString();
+ if (truncateLength > 0 && cellStr.length() > truncateLength) {
+ cellStr = cellStr.substring(0, truncateLength) + "...";
+ }
+ rowBuf.append(replaceReservedChars(cellStr));
+ }
+ rowBuf.append('\n');
+ out.write(rowBuf.toString().getBytes(StandardCharsets.UTF_8));
+ }
+ out.flush();
+ }
+
+ static String replaceReservedChars(String str) {
+ if (str == null) {
+ return "null";
+ }
+ return str.replace('\t', ' ').replace('\n', ' ');
+ }
+}
diff --git a/spark-connect/src/main/resources/interpreter-setting.json b/spark-connect/src/main/resources/interpreter-setting.json
new file mode 100644
index 00000000000..78d68fda936
--- /dev/null
+++ b/spark-connect/src/main/resources/interpreter-setting.json
@@ -0,0 +1,187 @@
+[
+ {
+ "group": "spark-connect",
+ "name": "spark-connect",
+ "className": "org.apache.zeppelin.spark.SparkConnectInterpreter",
+ "defaultInterpreter": true,
+ "properties": {
+ "spark.remote": {
+ "envName": "SPARK_REMOTE",
+ "propertyName": "spark.remote",
+ "defaultValue": "sc://localhost:15002",
+ "description": "Spark Connect server URI (e.g. sc://localhost:15002 or sc://dataproc-master:8080)",
+ "type": "string"
+ },
+ "spark.app.name": {
+ "envName": null,
+ "propertyName": "spark.app.name",
+ "defaultValue": "Zeppelin Spark Connect",
+ "description": "Spark application name",
+ "type": "string"
+ },
+ "zeppelin.spark.maxResult": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.maxResult",
+ "defaultValue": "1000",
+ "description": "Max number of rows to display",
+ "type": "number"
+ },
+ "zeppelin.spark.sql.stacktrace": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.sql.stacktrace",
+ "defaultValue": true,
+ "description": "Show full exception stacktrace for SQL errors",
+ "type": "checkbox"
+ },
+ "spark.connect.grpc.maxMessageSize": {
+ "envName": null,
+ "propertyName": "spark.connect.grpc.maxMessageSize",
+ "defaultValue": "134217728",
+ "description": "Max gRPC message size in bytes (default 128MB). Increase for large result sets.",
+ "type": "number"
+ },
+ "zeppelin.spark.connect.streamResults": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.connect.streamResults",
+ "defaultValue": false,
+ "description": "Stream query results row-by-row instead of building full result in memory. Recommended for large result sets.",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.connect.maxSessionsPerUser": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.connect.maxSessionsPerUser",
+ "defaultValue": "5",
+ "description": "Maximum number of Spark Connect sessions (SparkSession instances) per user. Each interpreter instance creates one session.",
+ "type": "number"
+ }
+ },
+ "editor": {
+ "language": "sql",
+ "editOnDblClick": false,
+ "completionKey": "TAB",
+ "completionSupport": false
+ }
+ },
+ {
+ "group": "spark-connect",
+ "name": "sql",
+ "className": "org.apache.zeppelin.spark.SparkConnectSqlInterpreter",
+ "properties": {
+ "zeppelin.spark.concurrentSQL": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.concurrentSQL",
+ "defaultValue": true,
+ "description": "Execute multiple SQL concurrently",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.concurrentSQL.max": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.concurrentSQL.max",
+ "defaultValue": "10",
+ "description": "Max concurrent SQL executions",
+ "type": "number"
+ },
+ "zeppelin.spark.sql.stacktrace": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.sql.stacktrace",
+ "defaultValue": true,
+ "description": "Show full exception stacktrace for SQL errors",
+ "type": "checkbox"
+ }
+ },
+ "editor": {
+ "language": "sql",
+ "editOnDblClick": false,
+ "completionKey": "TAB",
+ "completionSupport": false
+ }
+ },
+ {
+ "group": "spark-connect",
+ "name": "pyspark",
+ "className": "org.apache.zeppelin.spark.PySparkConnectInterpreter",
+ "properties": {
+ "spark.remote": {
+ "envName": "SPARK_REMOTE",
+ "propertyName": "spark.remote",
+ "defaultValue": "sc://localhost:15002",
+ "description": "Spark Connect server URI (e.g. sc://localhost:15002 or sc://dataproc-master:8080)",
+ "type": "string"
+ },
+ "spark.connect.token": {
+ "envName": "SPARK_CONNECT_TOKEN",
+ "propertyName": "spark.connect.token",
+ "defaultValue": "",
+ "description": "Authentication token for Spark Connect (optional)",
+ "type": "string"
+ },
+ "spark.connect.use_ssl": {
+ "envName": "SPARK_CONNECT_USE_SSL",
+ "propertyName": "spark.connect.use_ssl",
+ "defaultValue": false,
+ "description": "Use SSL for Spark Connect connection",
+ "type": "checkbox"
+ },
+ "spark.app.name": {
+ "envName": null,
+ "propertyName": "spark.app.name",
+ "defaultValue": "Zeppelin Spark Connect",
+ "description": "Spark application name",
+ "type": "string"
+ },
+ "spark.pyspark.driver.python": {
+ "envName": "PYSPARK_DRIVER_PYTHON",
+ "propertyName": "spark.pyspark.driver.python",
+ "defaultValue": "",
+ "description": "Python executable to use for PySpark driver (highest priority). Follows Spark's PySpark detection pattern.",
+ "type": "string"
+ },
+ "spark.pyspark.python": {
+ "envName": "PYSPARK_PYTHON",
+ "propertyName": "spark.pyspark.python",
+ "defaultValue": "",
+ "description": "Python executable to use for PySpark (second priority). Can also be set via PYSPARK_PYTHON environment variable.",
+ "type": "string"
+ },
+ "zeppelin.python": {
+ "envName": "ZEPPELIN_PYTHON",
+ "propertyName": "zeppelin.python",
+ "defaultValue": "",
+ "description": "Python executable command (optional fallback). If not set, defaults to 'python' and uses system PATH (matching Spark's PySpark behavior).",
+ "type": "string"
+ },
+ "zeppelin.pyspark.connect.useIPython": {
+ "envName": null,
+ "propertyName": "zeppelin.pyspark.connect.useIPython",
+ "defaultValue": true,
+ "description": "Use IPython if available",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.maxResult": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.maxResult",
+ "defaultValue": "1000",
+ "description": "Max number of rows to display",
+ "type": "number"
+ }
+ },
+ "editor": {
+ "language": "python",
+ "editOnDblClick": false,
+ "completionKey": "TAB",
+ "completionSupport": true
+ }
+ },
+ {
+ "group": "spark-connect",
+ "name": "ipyspark",
+ "className": "org.apache.zeppelin.spark.IPySparkConnectInterpreter",
+ "properties": {},
+ "editor": {
+ "language": "python",
+ "editOnDblClick": false,
+ "completionKey": "TAB",
+ "completionSupport": true
+ }
+ }
+]
diff --git a/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py b/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py
new file mode 100644
index 00000000000..17a989ddcba
--- /dev/null
+++ b/spark-connect/src/main/resources/python/zeppelin_isparkconnect.py
@@ -0,0 +1,632 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# IPython variant: same wrappers as zeppelin_sparkconnect.py.
+#
+# Guards against OOM on large collect/toPandas by enforcing configurable
+# row limits and providing safe iteration helpers.
+
+import sys
+import warnings
+
# `gateway` is a Py4j JavaGateway injected by the hosting Zeppelin process
# before this script runs; its entry point is the Java-side interpreter.
intp = gateway.entry_point
_jspark = intp.getSparkSession()  # Java SparkSession proxy
_max_result = intp.getMaxResult()  # zeppelin.spark.maxResult

# Default row cap for collect()/toPandas(); those APIs accept -1 for "all".
_COLLECT_LIMIT_DEFAULT = _max_result
# Row count above which an unbounded collect() emits an OOM warning.
_COLLECT_WARN_THRESHOLD = 100000
+
+
+class Row(tuple):
+ """Lightweight PySpark-compatible Row that wraps values extracted from Java Row objects."""
+
+ def __new__(cls, **kwargs):
+ row = tuple.__new__(cls, kwargs.values())
+ row.__dict__['_fields'] = tuple(kwargs.keys())
+ return row
+
+ def __repr__(self):
+ pairs = ", ".join("%s=%r" % (k, v) for k, v in zip(self._fields, self))
+ return "Row(%s)" % pairs
+
+ def __getattr__(self, name):
+ try:
+ idx = self._fields.index(name)
+ return self[idx]
+ except ValueError:
+ raise AttributeError("Row has no field '%s'" % name)
+
+ def asDict(self):
+ return dict(zip(self._fields, self))
+
+
def _convert_java_row(jrow, col_names):
    """Translate one Java Row proxy into a Python Row keyed by col_names."""
    def _pythonize(raw):
        # Py4j proxies expose getClass(); stringify anything still Java-side.
        return str(raw) if hasattr(raw, 'getClass') else raw

    return Row(**{col: _pythonize(jrow.get(i))
                  for i, col in enumerate(col_names)})
+
+
def _convert_java_rows(jdf):
    """Collect a Java Dataset to the driver and convert every row to a
    Python Row.

    NOTE: collects ALL rows of `jdf`; callers are expected to apply
    .limit() first (see SparkConnectDataFrame.collect/take).
    """
    fields = jdf.schema().fields()
    col_names = [f.name() for f in fields]
    jrows = jdf.collectAsList()
    return [_convert_java_row(r, col_names) for r in jrows]
+
+
+# ---------------------------------------------------------------------------
+# Py4j / type-conversion helpers for createDataFrame and __getattr__
+# ---------------------------------------------------------------------------
+
+def _is_java_object(obj):
+ """Check if obj is a Py4j proxy."""
+ return hasattr(obj, '_get_object_id')
+
+
+def _is_java_dataset(obj):
+ """Check if a Py4j proxy represents a Spark Dataset."""
+ if not _is_java_object(obj):
+ return False
+ try:
+ return 'Dataset' in obj.getClass().getName()
+ except Exception:
+ return False
+
+
# PySpark simple-type class name -> identically named static field on
# org.apache.spark.sql.types.DataTypes. The keys double as the whitelist of
# non-parameterized types handled directly by _pyspark_type_to_java;
# Decimal/Array/Map/Struct are handled separately there.
_PYSPARK_TO_JAVA_TYPES = {
    'StringType': 'StringType',
    'IntegerType': 'IntegerType',
    'LongType': 'LongType',
    'DoubleType': 'DoubleType',
    'FloatType': 'FloatType',
    'BooleanType': 'BooleanType',
    'ShortType': 'ShortType',
    'ByteType': 'ByteType',
    'DateType': 'DateType',
    'TimestampType': 'TimestampType',
    'BinaryType': 'BinaryType',
    'NullType': 'NullType',
}
+
+
def _pyspark_type_to_java(dt):
    """Convert a PySpark DataType instance to a Java DataType via Py4j.

    Simple types are resolved by class name through _PYSPARK_TO_JAVA_TYPES;
    Decimal/Array/Map/Struct are rebuilt recursively. Any unrecognized type
    silently falls back to StringType -- NOTE(review): confirm this lossy
    fallback is intended rather than raising.
    """
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    type_name = type(dt).__name__
    if type_name in _PYSPARK_TO_JAVA_TYPES:
        return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name])
    if type_name == 'DecimalType':
        return DataTypes.createDecimalType(dt.precision, dt.scale)
    if type_name == 'ArrayType':
        return DataTypes.createArrayType(
            _pyspark_type_to_java(dt.elementType), dt.containsNull)
    if type_name == 'MapType':
        return DataTypes.createMapType(
            _pyspark_type_to_java(dt.keyType),
            _pyspark_type_to_java(dt.valueType), dt.valueContainsNull)
    if type_name == 'StructType':
        return _pyspark_schema_to_java(dt)
    return DataTypes.StringType
+
+
def _pyspark_schema_to_java(pyspark_schema):
    """Convert a PySpark StructType to a Java StructType, field by field."""
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    java_fields = gateway.jvm.java.util.ArrayList()
    for field in pyspark_schema.fields:
        jtype = _pyspark_type_to_java(field.dataType)
        # Default to nullable=True when the PySpark field does not carry it.
        java_fields.add(DataTypes.createStructField(
            field.name, jtype, getattr(field, 'nullable', True)))
    return DataTypes.createStructType(java_fields)
+
+
def _infer_java_type(value):
    """Infer a Java DataType from a single Python value.

    None and any unrecognized type map to StringType as a safe default.
    """
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    if value is None:
        return DataTypes.StringType
    # bool must be tested before int: bool is a subclass of int in Python.
    if isinstance(value, bool):
        return DataTypes.BooleanType
    if isinstance(value, int):
        # Values beyond the signed 32-bit range need LongType.
        return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType
    if isinstance(value, float):
        return DataTypes.DoubleType
    return DataTypes.StringType
+
+
def _resolve_schema(schema, data):
    """Resolve any supported schema representation to a Java StructType.

    Accepts: None (infer from data), a Java StructType proxy, a PySpark
    StructType (anything with .fields), a DDL string, or a list/tuple of
    column names. Raises ValueError for anything else.
    """
    if schema is None:
        return _infer_schema(data)
    if _is_java_object(schema):
        return schema
    # The Java-proxy case returned above, so anything with .fields here is a
    # PySpark StructType (fix: dropped the redundant _is_java_object re-check).
    if hasattr(schema, 'fields'):
        return _pyspark_schema_to_java(schema)
    if isinstance(schema, str):
        try:
            return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema)
        except Exception as e:
            # Fix: chain the original parser error for easier debugging.
            raise ValueError("Cannot parse DDL schema: %s" % schema) from e
    if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str):
        return _schema_from_names(schema, data)
    raise ValueError("Unsupported schema type: %s" % type(schema).__name__)
+
+
def _infer_schema(data):
    """Infer a Java StructType from the FIRST element of data only.

    Supports Row, dict, list/tuple (columns named _1.._n) and scalars
    (single column named "value"). All fields are created nullable.
    """
    if not data:
        raise ValueError("Cannot infer schema from empty data without a schema")
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    first = data[0]
    if isinstance(first, Row):
        names, values = list(first._fields), list(first)
    elif isinstance(first, dict):
        names, values = list(first.keys()), list(first.values())
    elif isinstance(first, (list, tuple)):
        names = ["_%d" % (i + 1) for i in range(len(first))]
        values = list(first)
    else:
        names, values = ["value"], [first]
    java_fields = gateway.jvm.java.util.ArrayList()
    for i, name in enumerate(names):
        # Guard against ragged first elements (fewer values than names).
        java_fields.add(DataTypes.createStructField(
            name, _infer_java_type(values[i] if i < len(values) else None), True))
    return DataTypes.createStructType(java_fields)
+
+
def _schema_from_names(col_names, data):
    """Create a Java StructType from a column-name list, inferring each
    column's type from the first data element.

    Columns whose value cannot be resolved (no data, missing key, None)
    default to StringType. All fields are created nullable.
    """
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    first = data[0] if data else None
    java_fields = gateway.jvm.java.util.ArrayList()
    for i, name in enumerate(col_names):
        jtype = DataTypes.StringType
        if first is not None:
            val = None
            if isinstance(first, (list, tuple)) and i < len(first):
                val = first[i]
            elif isinstance(first, dict):
                val = first.get(name)
            elif isinstance(first, Row) and i < len(first):
                val = first[i]
            if val is not None:
                jtype = _infer_java_type(val)
        java_fields.add(DataTypes.createStructField(name, jtype, True))
    return DataTypes.createStructType(java_fields)
+
+
def _to_java_rows(data, col_names):
    """Convert Python rows (Row/dict/list/tuple/scalar) into a Java
    ArrayList of Java Rows via RowFactory.

    dict rows are reordered to match col_names; missing keys become null.
    """
    RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory
    java_rows = gateway.jvm.java.util.ArrayList()
    for item in data:
        if isinstance(item, Row):
            java_rows.add(RowFactory.create(*list(item)))
        elif isinstance(item, dict):
            java_rows.add(RowFactory.create(*[item.get(c) for c in col_names]))
        elif isinstance(item, (list, tuple)):
            java_rows.add(RowFactory.create(*list(item)))
        else:
            # Scalar -> single-column row.
            java_rows.add(RowFactory.create(item))
    return java_rows
+
+
+class SparkConnectDataFrame(object):
+ """Wrapper around a Java Dataset with production-safe data retrieval."""
+
    def __init__(self, jdf):
        # jdf: Py4j proxy for the underlying Java Dataset.
        self._jdf = jdf

    def show(self, n=20, truncate=True):
        # Rendering is delegated to the Java side; n is capped at
        # zeppelin.spark.maxResult. NOTE(review): `truncate` is accepted for
        # PySpark compatibility but currently ignored -- confirm whether
        # intp.formatDataFrame truncates on its own.
        effective_n = min(n, _max_result)
        print(intp.formatDataFrame(self._jdf, effective_n))
+
+ def collect(self, limit=None):
+ """Collect rows to the driver as Python Row objects.
+
+ Args:
+ limit: Max rows to collect. Defaults to zeppelin.spark.maxResult.
+ Pass limit=-1 to collect ALL rows (use with caution).
+ """
+ if limit is None:
+ limit = _COLLECT_LIMIT_DEFAULT
+ if limit == -1:
+ row_count = self._jdf.count()
+ if row_count > _COLLECT_WARN_THRESHOLD:
+ warnings.warn(
+ "Collecting %d rows to driver. This may cause OOM. "
+ "Consider using .limit() or .toPandas() with a smaller subset."
+ % row_count)
+ return _convert_java_rows(self._jdf)
+ return _convert_java_rows(self._jdf.limit(limit))
+
    def take(self, n):
        # Pushes the limit to the server before collecting.
        return _convert_java_rows(self._jdf.limit(n))

    def head(self, n=1):
        # For n == 1 returns a single Row (or None if empty); otherwise a
        # list -- matching PySpark semantics.
        rows = self.take(n)
        if n == 1:
            return rows[0] if rows else None
        return rows

    def first(self):
        # Single Row, or None for an empty DataFrame.
        return self.head(1)
+
+ def toPandas(self, limit=None):
+ """Convert to pandas DataFrame. Applies a safety limit.
+
+ Args:
+ limit: Max rows. Defaults to zeppelin.spark.maxResult.
+ Pass limit=-1 for all rows (use with caution on large data).
+ """
+ try:
+ import pandas as pd
+ except ImportError:
+ raise ImportError(
+ "pandas is required for toPandas(). "
+ "Install it with: pip install pandas")
+
+ if limit is None:
+ limit = _COLLECT_LIMIT_DEFAULT
+ if limit == -1:
+ source_jdf = self._jdf
+ else:
+ source_jdf = self._jdf.limit(limit)
+
+ fields = source_jdf.schema().fields()
+ col_names = [f.name() for f in fields]
+ jrows = source_jdf.collectAsList()
+
+ if len(jrows) == 0:
+ return pd.DataFrame(columns=col_names)
+
+ rows_data = []
+ for row in jrows:
+ rows_data.append([row.get(i) for i in range(len(col_names))])
+
+ return pd.DataFrame(rows_data, columns=col_names)
+
    # --- Relational transformations: each returns a new wrapped DataFrame ---

    def count(self):
        # Triggers a job on the Connect server.
        return self._jdf.count()

    def limit(self, n):
        return SparkConnectDataFrame(self._jdf.limit(n))

    def filter(self, condition):
        # condition: SQL expression string or a Java Column proxy.
        return SparkConnectDataFrame(self._jdf.filter(condition))

    def select(self, *cols):
        # NOTE(review): Py4j does not automatically spread Python *args onto
        # Java varargs overloads -- verify multi-column calls work.
        return SparkConnectDataFrame(self._jdf.select(*cols))

    def where(self, condition):
        # Alias of filter(), as in PySpark.
        return self.filter(condition)

    def groupBy(self, *cols):
        # NOTE(review): returns the raw Java grouped-data proxy, NOT a wrapped
        # DataFrame -- downstream .agg() results will be unwrapped Java
        # objects. Confirm this asymmetry is intended.
        return self._jdf.groupBy(*cols)

    def orderBy(self, *cols):
        return SparkConnectDataFrame(self._jdf.orderBy(*cols))

    def sort(self, *cols):
        # Alias of orderBy().
        return self.orderBy(*cols)

    def distinct(self):
        return SparkConnectDataFrame(self._jdf.distinct())

    def drop(self, *cols):
        return SparkConnectDataFrame(self._jdf.drop(*cols))

    def dropDuplicates(self, *cols):
        # With columns: dedupe on that subset; without: on all columns.
        if cols:
            return SparkConnectDataFrame(self._jdf.dropDuplicates(*cols))
        return SparkConnectDataFrame(self._jdf.dropDuplicates())
+
    def join(self, other, on=None, how="inner"):
        # Accepts a wrapped or raw Java DataFrame. `how` is only passed
        # through when `on` is given, matching the Java overloads.
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        if on is not None:
            return SparkConnectDataFrame(self._jdf.join(other_jdf, on, how))
        return SparkConnectDataFrame(self._jdf.join(other_jdf))

    def union(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.union(other_jdf))

    def withColumn(self, colName, col):
        # col: Java Column proxy.
        return SparkConnectDataFrame(self._jdf.withColumn(colName, col))

    def withColumnRenamed(self, existing, new):
        return SparkConnectDataFrame(self._jdf.withColumnRenamed(existing, new))

    # --- Caching / persistence: act server-side and return self for chaining ---

    def cache(self):
        self._jdf.cache()
        return self

    def persist(self, storageLevel=None):
        if storageLevel:
            self._jdf.persist(storageLevel)
        else:
            self._jdf.persist()
        return self

    def unpersist(self, blocking=False):
        self._jdf.unpersist(blocking)
        return self

    def explain(self, extended=False):
        # Prints the query plan (Java side chooses the destination stream).
        if extended:
            self._jdf.explain(True)
        else:
            self._jdf.explain()
+
    # --- Views / metadata ---

    def createOrReplaceTempView(self, name):
        self._jdf.createOrReplaceTempView(name)

    def createTempView(self, name):
        # Fails (Java-side exception) if the view already exists.
        self._jdf.createTempView(name)

    def schema(self):
        # Raw Java StructType proxy. NOTE(review): PySpark exposes
        # schema/dtypes/columns as properties, not methods -- callers of this
        # wrapper must use call syntax; confirm that difference is acceptable.
        return self._jdf.schema()

    def dtypes(self):
        schema = self._jdf.schema()
        return [(f.name(), str(f.dataType())) for f in schema.fields()]

    def columns(self):
        schema = self._jdf.schema()
        return [f.name() for f in schema.fields()]

    def printSchema(self):
        print(self._jdf.schema().treeString())

    def describe(self, *cols):
        if cols:
            return SparkConnectDataFrame(self._jdf.describe(*cols))
        return SparkConnectDataFrame(self._jdf.describe())

    def summary(self, *statistics):
        if statistics:
            return SparkConnectDataFrame(self._jdf.summary(*statistics))
        return SparkConnectDataFrame(self._jdf.summary())

    def isEmpty(self):
        return self._jdf.isEmpty()

    # --- Partitioning ---

    def repartition(self, numPartitions, *cols):
        if cols:
            return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols))
        return SparkConnectDataFrame(self._jdf.repartition(numPartitions))

    def coalesce(self, numPartitions):
        return SparkConnectDataFrame(self._jdf.coalesce(numPartitions))

    def toDF(self, *cols):
        # Rename all columns at once.
        return SparkConnectDataFrame(self._jdf.toDF(*cols))

    # --- Set operations; accept wrapped or raw Java DataFrames ---

    def unionByName(self, other, allowMissingColumns=False):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(
            self._jdf.unionByName(other_jdf, allowMissingColumns))

    def crossJoin(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf))

    def subtract(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.subtract(other_jdf))

    def intersect(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.intersect(other_jdf))

    def exceptAll(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf))
+
+ def sample(self, withReplacement=None, fraction=None, seed=None):
+ if withReplacement is None and fraction is None:
+ raise ValueError("fraction must be specified")
+ if isinstance(withReplacement, float) and fraction is None:
+ fraction = withReplacement
+ withReplacement = False
+ if withReplacement is None:
+ withReplacement = False
+ if seed is not None:
+ return SparkConnectDataFrame(
+ self._jdf.sample(withReplacement, fraction, seed))
+ return SparkConnectDataFrame(
+ self._jdf.sample(withReplacement, fraction))
+
    def dropna(self, how="any", thresh=None, subset=None):
        # Delegates to the Java DataFrameNaFunctions handle.
        # NOTE(review): passing a Python list as `subset` assumes Py4j
        # auto-converts it to a Java array/collection — confirm gateway
        # auto_convert is enabled.
        na = self._jdf.na()
        if thresh is not None:
            if subset:
                return SparkConnectDataFrame(na.drop(thresh, subset))
            return SparkConnectDataFrame(na.drop(thresh))
        if subset:
            return SparkConnectDataFrame(na.drop(how, subset))
        return SparkConnectDataFrame(na.drop(how))

    def fillna(self, value, subset=None):
        na = self._jdf.na()
        if subset:
            return SparkConnectDataFrame(na.fill(value, subset))
        return SparkConnectDataFrame(na.fill(value))

    @property
    def write(self):
        # Returns the raw Java DataFrameWriter (not wrapped).
        return self._jdf.write()

    def __repr__(self):
        # Schema fetch can fail if the session died; degrade gracefully.
        try:
            return "SparkConnectDataFrame[%s]" % ", ".join(
                f.name() for f in self._jdf.schema().fields())
        except Exception:
            return "SparkConnectDataFrame[schema unavailable]"

    def __getattr__(self, name):
        # Fallback for any Dataset method not wrapped above; results that
        # are Java Datasets are re-wrapped so chaining keeps working.
        attr = getattr(self._jdf, name)
        if not callable(attr):
            return attr
        def _method_wrapper(*args, **kwargs):
            result = attr(*args, **kwargs)
            if _is_java_dataset(result):
                return SparkConnectDataFrame(result)
            return result
        return _method_wrapper

    def __iter__(self):
        """Safe iteration with default limit to prevent OOM."""
        return iter(_convert_java_rows(self._jdf.limit(_COLLECT_LIMIT_DEFAULT)))

    def __len__(self):
        # Triggers a full count() job on the cluster.
        return int(self._jdf.count())
+
+
class SparkConnectSession(object):
    """Wraps the Java SparkSession so that sql() returns a wrapped DataFrame."""

    def __init__(self, jsession):
        # jsession is a Py4j proxy for the Java SparkSession.
        self._jsession = jsession

    def sql(self, query):
        return SparkConnectDataFrame(self._jsession.sql(query))

    def table(self, tableName):
        return SparkConnectDataFrame(self._jsession.table(tableName))

    def read(self):
        # NOTE(review): in PySpark `read` is a property; here it is a method
        # returning the raw Java DataFrameReader — confirm divergence is ok.
        return self._jsession.read()

    def createDataFrame(self, data, schema=None):
        """Create a SparkConnectDataFrame from Python data.

        Supports:
          - data: list of Row, list of tuples, list of dicts, pandas DataFrame
          - schema: PySpark StructType, list of column names, DDL string,
            Java StructType (Py4j proxy), or None (infer from data)
        """
        # pandas is optional; a DataFrame becomes a list-of-lists plus its
        # column names as the schema.
        try:
            import pandas as pd
            if isinstance(data, pd.DataFrame):
                if schema is None:
                    schema = list(data.columns)
                data = data.values.tolist()
        except ImportError:
            pass

        # If data is already a Java object, hand it straight to the session.
        if _is_java_object(data):
            if schema is None:
                return SparkConnectDataFrame(self._jsession.createDataFrame(data))
            java_schema = _resolve_schema(schema, None)
            return SparkConnectDataFrame(
                self._jsession.createDataFrame(data, java_schema))

        java_schema = _resolve_schema(schema, data)
        col_names = [f.name() for f in java_schema.fields()]
        java_rows = _to_java_rows(data, col_names)
        return SparkConnectDataFrame(
            self._jsession.createDataFrame(java_rows, java_schema))

    def range(self, start, end=None, step=1, numPartitions=None):
        # Single-argument form mirrors PySpark: range(n) == range(0, n).
        if end is None:
            end = start
            start = 0
        if numPartitions:
            return SparkConnectDataFrame(
                self._jsession.range(start, end, step, numPartitions))
        return SparkConnectDataFrame(self._jsession.range(start, end, step))

    @property
    def catalog(self):
        # Raw Java Catalog proxy.
        return self._jsession.catalog()

    @property
    def version(self):
        return self._jsession.version()

    @property
    def conf(self):
        # Raw Java RuntimeConfig proxy.
        return self._jsession.conf()

    def stop(self):
        # Intentional no-op: the session lifecycle is owned by the Java-side
        # interpreter, not by notebook code.
        pass

    def __repr__(self):
        return "SparkConnectSession (via Py4j)"

    def __getattr__(self, name):
        # Fall through to the Java session for anything not wrapped above.
        return getattr(self._jsession, name)
+
+
def pip_install(*packages):
    """Install Python packages into the interpreter pod's environment.

    Usage:
        pip_install("requests")
        pip_install("requests", "pandas", "numpy==1.24.0")
        pip_install("requests>=2.28,<3.0")
    """
    import subprocess
    import importlib
    import site
    if not packages:
        print("Usage: pip_install('package1', 'package2', ...)")
        return
    cmd = [sys.executable, "-m", "pip", "install", "--quiet"] + list(packages)
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        if result.returncode == 0:
            installed = ", ".join(packages)
            print("Successfully installed: %s" % installed)
            if result.stdout.strip():
                print(result.stdout.strip())
            # Refresh import machinery so the new packages are importable
            # in this interpreter session without a restart.
            importlib.invalidate_caches()
            for new_path in site.getsitepackages() + [site.getusersitepackages()]:
                if new_path not in sys.path:
                    sys.path.insert(0, new_path)
        else:
            print("pip install failed (exit code %d):" % result.returncode)
            if result.stderr.strip():
                print(result.stderr.strip())
            if result.stdout.strip():
                print(result.stdout.strip())
    except subprocess.TimeoutExpired:
        print("pip install timed out after 300 seconds")
    except Exception as e:
        print("pip install error: %s" % str(e))
+
+
def display(obj, n=20):
    """Databricks-compatible display function.

    For SparkConnectDataFrame, renders as Zeppelin %table format.
    For other objects, falls back to print().
    """
    if isinstance(obj, SparkConnectDataFrame):
        obj.show(n)
    else:
        print(obj)
+
+
# Module-level handles mirroring the names PySpark notebook users expect.
spark = SparkConnectSession(_jspark)
sqlContext = sqlc = spark  # legacy aliases for SQLContext-era notebooks
diff --git a/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py b/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py
new file mode 100644
index 00000000000..46250f44674
--- /dev/null
+++ b/spark-connect/src/main/resources/python/zeppelin_sparkconnect.py
@@ -0,0 +1,637 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Reuse the Java SparkSession from SparkConnectInterpreter via Py4j.
+# Wraps the Java DataFrame so that show()/collect() output goes to
+# Python stdout (which Zeppelin captures), matching the SQL interpreter behavior.
+#
+# Guards against OOM on large collect/toPandas by enforcing configurable
+# row limits and providing safe iteration helpers.
+
+import sys
+import warnings
+
# `gateway` is injected by the Py4j bootstrap; its entry point is the
# Java-side SparkConnectInterpreter instance.
intp = gateway.entry_point
_jspark = intp.getSparkSession()
_max_result = intp.getMaxResult()  # zeppelin.spark.maxResult

# Default row cap for collect()/toPandas()/iteration; -1 disables per call.
_COLLECT_LIMIT_DEFAULT = _max_result
# Above this row count, an unlimited collect emits an OOM warning.
_COLLECT_WARN_THRESHOLD = 100000
+
+
class Row(tuple):
    """Lightweight PySpark-compatible Row that wraps values extracted from Java Row objects.

    Field values are stored in tuple order; field names are kept in the
    instance ``__dict__`` under ``_fields`` (relies on Python 3.7+
    keyword-argument ordering).
    """

    def __new__(cls, **kwargs):
        row = tuple.__new__(cls, kwargs.values())
        row.__dict__['_fields'] = tuple(kwargs.keys())
        return row

    def __repr__(self):
        pairs = ", ".join("%s=%r" % (k, v) for k, v in zip(self._fields, self))
        return "Row(%s)" % pairs

    def __getattr__(self, name):
        # Read _fields via __dict__.get: instances materialized without
        # __new__ (e.g. by copy/pickle protocols) would otherwise trigger
        # infinite recursion when _fields itself is looked up here.
        fields = self.__dict__.get('_fields')
        if fields is None:
            raise AttributeError("Row has no field '%s'" % name)
        try:
            idx = fields.index(name)
        except ValueError:
            raise AttributeError("Row has no field '%s'" % name)
        return self[idx]

    def asDict(self):
        """Return the row as a plain ``{field: value}`` dict."""
        return dict(zip(self._fields, self))
+
+
def _convert_java_row(jrow, col_names):
    """Convert a single Java Row proxy into a Python Row keyed by col_names."""
    values = {}
    for idx, name in enumerate(col_names):
        item = jrow.get(idx)
        # Anything still carrying getClass is a Py4j proxy; stringify it
        # since it cannot be used as a plain Python value.
        values[name] = str(item) if hasattr(item, 'getClass') else item
    return Row(**values)
+
+
def _convert_java_rows(jdf):
    """Collect a Java Dataset to the driver and convert every row to a Python Row."""
    names = [f.name() for f in jdf.schema().fields()]
    return [_convert_java_row(jrow, names) for jrow in jdf.collectAsList()]
+
+
+# ---------------------------------------------------------------------------
+# Py4j / type-conversion helpers for createDataFrame and __getattr__
+# ---------------------------------------------------------------------------
+
+def _is_java_object(obj):
+ """Check if obj is a Py4j proxy."""
+ return hasattr(obj, '_get_object_id')
+
+
def _is_java_dataset(obj):
    """Return True when a Py4j proxy wraps an org.apache.spark.sql.Dataset."""
    if not _is_java_object(obj):
        return False
    try:
        class_name = obj.getClass().getName()
    except Exception:
        # Proxy without a reachable getClass (e.g. gateway gone) is not a Dataset.
        return False
    return 'Dataset' in class_name
+
+
# Simple (non-parameterized) PySpark type-class names mapped to static
# fields of org.apache.spark.sql.types.DataTypes. Parameterized types
# (Decimal/Array/Map/Struct) are handled explicitly in _pyspark_type_to_java.
_PYSPARK_TO_JAVA_TYPES = {
    'StringType': 'StringType',
    'IntegerType': 'IntegerType',
    'LongType': 'LongType',
    'DoubleType': 'DoubleType',
    'FloatType': 'FloatType',
    'BooleanType': 'BooleanType',
    'ShortType': 'ShortType',
    'ByteType': 'ByteType',
    'DateType': 'DateType',
    'TimestampType': 'TimestampType',
    'BinaryType': 'BinaryType',
    'NullType': 'NullType',
}
+
+
def _pyspark_type_to_java(dt):
    """Convert a PySpark DataType instance to a Java DataType via Py4j gateway."""
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    # Dispatch on the Python class name rather than isinstance, so pyspark
    # does not have to be importable here.
    type_name = type(dt).__name__
    if type_name in _PYSPARK_TO_JAVA_TYPES:
        return getattr(DataTypes, _PYSPARK_TO_JAVA_TYPES[type_name])
    if type_name == 'DecimalType':
        return DataTypes.createDecimalType(dt.precision, dt.scale)
    if type_name == 'ArrayType':
        return DataTypes.createArrayType(
            _pyspark_type_to_java(dt.elementType), dt.containsNull)
    if type_name == 'MapType':
        return DataTypes.createMapType(
            _pyspark_type_to_java(dt.keyType),
            _pyspark_type_to_java(dt.valueType), dt.valueContainsNull)
    if type_name == 'StructType':
        return _pyspark_schema_to_java(dt)
    # NOTE(review): unknown types silently fall back to StringType — lossy;
    # confirm this default is intended rather than raising.
    return DataTypes.StringType
+
+
def _pyspark_schema_to_java(pyspark_schema):
    """Convert a PySpark StructType to a Java StructType."""
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    # Build the field list on the JVM side; DataTypes.createStructType
    # accepts a java.util.List<StructField>.
    java_fields = gateway.jvm.java.util.ArrayList()
    for field in pyspark_schema.fields:
        jtype = _pyspark_type_to_java(field.dataType)
        java_fields.add(DataTypes.createStructField(
            field.name, jtype, getattr(field, 'nullable', True)))
    return DataTypes.createStructType(java_fields)
+
+
def _infer_java_type(value):
    """Infer a Java DataType from a Python value."""
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    if value is None:
        # No information: default to string.
        return DataTypes.StringType
    # bool must be tested before int (bool is a subclass of int in Python).
    if isinstance(value, bool):
        return DataTypes.BooleanType
    if isinstance(value, int):
        # Use LongType only when the sample value exceeds the int32 range.
        # NOTE(review): inferring from a single sample may under-size later
        # rows; confirm callers accept that.
        return DataTypes.LongType if abs(value) > 2147483647 else DataTypes.IntegerType
    if isinstance(value, float):
        return DataTypes.DoubleType
    return DataTypes.StringType
+
+
def _resolve_schema(schema, data):
    """Resolve any supported schema representation to a Java StructType.

    Accepts None (infer from data), a Py4j proxy (already a Java schema),
    a PySpark StructType (duck-typed via its ``fields`` attribute), a DDL
    string, or a list/tuple of column names.

    Raises:
        ValueError: for an unparseable DDL string or an unsupported type.
    """
    if schema is None:
        return _infer_schema(data)
    if _is_java_object(schema):
        return schema
    # PySpark StructType duck-typing; Py4j proxies were already returned above,
    # so no extra _is_java_object re-check is needed here.
    if hasattr(schema, 'fields'):
        return _pyspark_schema_to_java(schema)
    if isinstance(schema, str):
        try:
            return gateway.jvm.org.apache.spark.sql.types.StructType.fromDDL(schema)
        except Exception as e:
            # Chain the underlying parse error instead of discarding it.
            raise ValueError("Cannot parse DDL schema: %s" % schema) from e
    if isinstance(schema, (list, tuple)) and schema and isinstance(schema[0], str):
        return _schema_from_names(schema, data)
    raise ValueError("Unsupported schema type: %s" % type(schema).__name__)
+
+
def _infer_schema(data):
    """Infer a Java StructType from the first element of the data."""
    if not data:
        raise ValueError("Cannot infer schema from empty data without a schema")
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    # Only the first element is sampled; later elements are assumed to have
    # the same shape and compatible types.
    first = data[0]
    if isinstance(first, Row):
        names, values = list(first._fields), list(first)
    elif isinstance(first, dict):
        names, values = list(first.keys()), list(first.values())
    elif isinstance(first, (list, tuple)):
        # Positional data gets Spark-style synthetic names _1, _2, ...
        names = ["_%d" % (i + 1) for i in range(len(first))]
        values = list(first)
    else:
        # Scalar data becomes a single "value" column.
        names, values = ["value"], [first]
    java_fields = gateway.jvm.java.util.ArrayList()
    for i, name in enumerate(names):
        java_fields.add(DataTypes.createStructField(
            name, _infer_java_type(values[i] if i < len(values) else None), True))
    return DataTypes.createStructType(java_fields)
+
+
def _schema_from_names(col_names, data):
    """Create a Java StructType from column name list, inferring types from data."""
    DataTypes = gateway.jvm.org.apache.spark.sql.types.DataTypes
    # Types are inferred from the first data element only; columns whose
    # first value is None default to StringType.
    first = data[0] if data else None
    java_fields = gateway.jvm.java.util.ArrayList()
    for i, name in enumerate(col_names):
        jtype = DataTypes.StringType
        if first is not None:
            val = None
            if isinstance(first, (list, tuple)) and i < len(first):
                val = first[i]
            elif isinstance(first, dict):
                val = first.get(name)
            elif isinstance(first, Row) and i < len(first):
                val = first[i]
            if val is not None:
                jtype = _infer_java_type(val)
        # All inferred fields are declared nullable.
        java_fields.add(DataTypes.createStructField(name, jtype, True))
    return DataTypes.createStructType(java_fields)
+
+
def _to_java_rows(data, col_names):
    """Convert Python data (list of Row/dict/tuple/list) to Java ArrayList."""
    RowFactory = gateway.jvm.org.apache.spark.sql.RowFactory
    java_rows = gateway.jvm.java.util.ArrayList()
    for item in data:
        if isinstance(item, Row):
            java_rows.add(RowFactory.create(*list(item)))
        elif isinstance(item, dict):
            # Order dict values by the schema's column order; missing keys
            # become null on the Java side.
            java_rows.add(RowFactory.create(*[item.get(c) for c in col_names]))
        elif isinstance(item, (list, tuple)):
            java_rows.add(RowFactory.create(*list(item)))
        else:
            # Scalar -> single-column row.
            java_rows.add(RowFactory.create(item))
    return java_rows
+
+
class SparkConnectDataFrame(object):
    """Wrapper around a Java Dataset with production-safe data retrieval.

    Transformations return new wrapped frames; actions that pull data to the
    driver (collect/toPandas/iteration) are capped at zeppelin.spark.maxResult
    rows unless a per-call limit of -1 is given.
    """

    def __init__(self, jdf):
        # jdf is a Py4j proxy for org.apache.spark.sql.Dataset<Row>.
        self._jdf = jdf

    def show(self, n=20, truncate=True):
        # Rendering is delegated to the Java interpreter so output appears
        # in Zeppelin %table format. NOTE(review): `truncate` is accepted
        # for PySpark compatibility but is currently ignored — confirm.
        effective_n = min(n, _max_result)
        print(intp.formatDataFrame(self._jdf, effective_n))

    def collect(self, limit=None):
        """Collect rows to the driver as Python Row objects.

        Args:
            limit: Max rows to collect. Defaults to zeppelin.spark.maxResult.
                   Pass limit=-1 to collect ALL rows (use with caution).
        """
        if limit is None:
            limit = _COLLECT_LIMIT_DEFAULT
        if limit == -1:
            # Unlimited collect: count first so we can warn before OOM risk.
            row_count = self._jdf.count()
            if row_count > _COLLECT_WARN_THRESHOLD:
                warnings.warn(
                    "Collecting %d rows to driver. This may cause OOM. "
                    "Consider using .limit() or .toPandas() with a smaller subset."
                    % row_count)
            return _convert_java_rows(self._jdf)
        return _convert_java_rows(self._jdf.limit(limit))

    def take(self, n):
        return _convert_java_rows(self._jdf.limit(n))

    def head(self, n=1):
        # NOTE(review): unlike PySpark, head(1) returns a Row (or None)
        # rather than a one-element list — confirm callers expect this.
        rows = self.take(n)
        if n == 1:
            return rows[0] if rows else None
        return rows

    def first(self):
        return self.head(1)

    def toPandas(self, limit=None):
        """Convert to pandas DataFrame. Applies a safety limit.

        Tries to use pyarrow for efficient serialization if available,
        otherwise falls back to row-by-row conversion through Py4j.

        Args:
            limit: Max rows. Defaults to zeppelin.spark.maxResult.
                   Pass limit=-1 for all rows (use with caution on large data).
        """
        # NOTE(review): docstring mentions a pyarrow fast path, but this body
        # only does row-by-row Py4j conversion — confirm or amend the doc.
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "pandas is required for toPandas(). "
                "Install it with: pip install pandas")

        if limit is None:
            limit = _COLLECT_LIMIT_DEFAULT
        if limit == -1:
            source_jdf = self._jdf
        else:
            source_jdf = self._jdf.limit(limit)

        fields = source_jdf.schema().fields()
        col_names = [f.name() for f in fields]
        jrows = source_jdf.collectAsList()

        if len(jrows) == 0:
            return pd.DataFrame(columns=col_names)

        rows_data = []
        for row in jrows:
            rows_data.append([row.get(i) for i in range(len(col_names))])

        return pd.DataFrame(rows_data, columns=col_names)

    def count(self):
        return self._jdf.count()

    def limit(self, n):
        return SparkConnectDataFrame(self._jdf.limit(n))

    def filter(self, condition):
        return SparkConnectDataFrame(self._jdf.filter(condition))

    def select(self, *cols):
        return SparkConnectDataFrame(self._jdf.select(*cols))

    def where(self, condition):
        # PySpark alias for filter().
        return self.filter(condition)

    def groupBy(self, *cols):
        # NOTE(review): returns the raw Java RelationalGroupedDataset, so
        # the result of a subsequent .agg() is NOT wrapped — confirm.
        return self._jdf.groupBy(*cols)

    def orderBy(self, *cols):
        return SparkConnectDataFrame(self._jdf.orderBy(*cols))

    def sort(self, *cols):
        return self.orderBy(*cols)

    def distinct(self):
        return SparkConnectDataFrame(self._jdf.distinct())

    def drop(self, *cols):
        return SparkConnectDataFrame(self._jdf.drop(*cols))

    def dropDuplicates(self, *cols):
        if cols:
            return SparkConnectDataFrame(self._jdf.dropDuplicates(*cols))
        return SparkConnectDataFrame(self._jdf.dropDuplicates())

    def join(self, other, on=None, how="inner"):
        # Accept either a wrapped or raw Java Dataset on the right side.
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        if on is not None:
            return SparkConnectDataFrame(self._jdf.join(other_jdf, on, how))
        return SparkConnectDataFrame(self._jdf.join(other_jdf))

    def union(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.union(other_jdf))

    def withColumn(self, colName, col):
        return SparkConnectDataFrame(self._jdf.withColumn(colName, col))

    def withColumnRenamed(self, existing, new):
        return SparkConnectDataFrame(self._jdf.withColumnRenamed(existing, new))

    def cache(self):
        # Mutates the underlying Dataset; returns self for chaining.
        self._jdf.cache()
        return self

    def persist(self, storageLevel=None):
        if storageLevel:
            self._jdf.persist(storageLevel)
        else:
            self._jdf.persist()
        return self

    def unpersist(self, blocking=False):
        self._jdf.unpersist(blocking)
        return self

    def explain(self, extended=False):
        # Java-side explain prints to the JVM stdout, not Python stdout.
        # NOTE(review): the plan may not show in Zeppelin output — confirm.
        if extended:
            self._jdf.explain(True)
        else:
            self._jdf.explain()

    def createOrReplaceTempView(self, name):
        self._jdf.createOrReplaceTempView(name)

    def createTempView(self, name):
        self._jdf.createTempView(name)

    # NOTE(review): schema/dtypes/columns are methods here, not properties as
    # in PySpark; df.columns without a call yields a bound method — confirm.

    def schema(self):
        return self._jdf.schema()

    def dtypes(self):
        schema = self._jdf.schema()
        return [(f.name(), str(f.dataType())) for f in schema.fields()]

    def columns(self):
        schema = self._jdf.schema()
        return [f.name() for f in schema.fields()]

    def printSchema(self):
        print(self._jdf.schema().treeString())

    def describe(self, *cols):
        if cols:
            return SparkConnectDataFrame(self._jdf.describe(*cols))
        return SparkConnectDataFrame(self._jdf.describe())

    def summary(self, *statistics):
        if statistics:
            return SparkConnectDataFrame(self._jdf.summary(*statistics))
        return SparkConnectDataFrame(self._jdf.summary())

    def isEmpty(self):
        return self._jdf.isEmpty()

    def __repr__(self):
        # Schema fetch can fail if the session died; degrade gracefully.
        try:
            return "SparkConnectDataFrame[%s]" % ", ".join(
                f.name() for f in self._jdf.schema().fields())
        except Exception:
            return "SparkConnectDataFrame[schema unavailable]"

    def repartition(self, numPartitions, *cols):
        if cols:
            return SparkConnectDataFrame(self._jdf.repartition(numPartitions, *cols))
        return SparkConnectDataFrame(self._jdf.repartition(numPartitions))

    def coalesce(self, numPartitions):
        return SparkConnectDataFrame(self._jdf.coalesce(numPartitions))

    def toDF(self, *cols):
        return SparkConnectDataFrame(self._jdf.toDF(*cols))

    def unionByName(self, other, allowMissingColumns=False):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(
            self._jdf.unionByName(other_jdf, allowMissingColumns))

    def crossJoin(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.crossJoin(other_jdf))

    def subtract(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.subtract(other_jdf))

    def intersect(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.intersect(other_jdf))

    def exceptAll(self, other):
        other_jdf = other._jdf if isinstance(other, SparkConnectDataFrame) else other
        return SparkConnectDataFrame(self._jdf.exceptAll(other_jdf))

    def sample(self, withReplacement=None, fraction=None, seed=None):
        # PySpark-style coercion: sample(0.1) means fraction=0.1.
        if withReplacement is None and fraction is None:
            raise ValueError("fraction must be specified")
        if isinstance(withReplacement, float) and fraction is None:
            fraction = withReplacement
            withReplacement = False
        if withReplacement is None:
            withReplacement = False
        if seed is not None:
            return SparkConnectDataFrame(
                self._jdf.sample(withReplacement, fraction, seed))
        return SparkConnectDataFrame(
            self._jdf.sample(withReplacement, fraction))

    def dropna(self, how="any", thresh=None, subset=None):
        # Delegates to Java DataFrameNaFunctions. NOTE(review): assumes Py4j
        # converts the Python `subset` list — confirm gateway auto_convert.
        na = self._jdf.na()
        if thresh is not None:
            if subset:
                return SparkConnectDataFrame(na.drop(thresh, subset))
            return SparkConnectDataFrame(na.drop(thresh))
        if subset:
            return SparkConnectDataFrame(na.drop(how, subset))
        return SparkConnectDataFrame(na.drop(how))

    def fillna(self, value, subset=None):
        na = self._jdf.na()
        if subset:
            return SparkConnectDataFrame(na.fill(value, subset))
        return SparkConnectDataFrame(na.fill(value))

    @property
    def write(self):
        # Returns the raw Java DataFrameWriter (not wrapped).
        return self._jdf.write()

    def __getattr__(self, name):
        # Fallback for any Dataset method not wrapped above; Java Dataset
        # results are re-wrapped so chaining keeps working.
        attr = getattr(self._jdf, name)
        if not callable(attr):
            return attr
        def _method_wrapper(*args, **kwargs):
            result = attr(*args, **kwargs)
            if _is_java_dataset(result):
                return SparkConnectDataFrame(result)
            return result
        return _method_wrapper

    def __iter__(self):
        """Safe iteration with default limit to prevent OOM."""
        return iter(_convert_java_rows(self._jdf.limit(_COLLECT_LIMIT_DEFAULT)))

    def __len__(self):
        # Triggers a full count() job on the cluster.
        return int(self._jdf.count())
+
+
class SparkConnectSession(object):
    """Wraps the Java SparkSession so that sql() returns a wrapped DataFrame."""

    def __init__(self, jsession):
        # jsession is a Py4j proxy for the Java SparkSession.
        self._jsession = jsession

    def sql(self, query):
        return SparkConnectDataFrame(self._jsession.sql(query))

    def table(self, tableName):
        return SparkConnectDataFrame(self._jsession.table(tableName))

    def read(self):
        # NOTE(review): in PySpark `read` is a property; here it is a method
        # returning the raw Java DataFrameReader — confirm divergence is ok.
        return self._jsession.read()

    def createDataFrame(self, data, schema=None):
        """Create a SparkConnectDataFrame from Python data.

        Supports:
          - data: list of Row, list of tuples, list of dicts, pandas DataFrame
          - schema: PySpark StructType, list of column names, DDL string,
            Java StructType (Py4j proxy), or None (infer from data)
        """
        # pandas is optional; a DataFrame becomes a list-of-lists plus its
        # column names as the schema.
        try:
            import pandas as pd
            if isinstance(data, pd.DataFrame):
                if schema is None:
                    schema = list(data.columns)
                data = data.values.tolist()
        except ImportError:
            pass

        # Data that is already a Java object goes straight to the session.
        if _is_java_object(data):
            if schema is None:
                return SparkConnectDataFrame(self._jsession.createDataFrame(data))
            java_schema = _resolve_schema(schema, None)
            return SparkConnectDataFrame(
                self._jsession.createDataFrame(data, java_schema))

        java_schema = _resolve_schema(schema, data)
        col_names = [f.name() for f in java_schema.fields()]
        java_rows = _to_java_rows(data, col_names)
        return SparkConnectDataFrame(
            self._jsession.createDataFrame(java_rows, java_schema))

    def range(self, start, end=None, step=1, numPartitions=None):
        # Single-argument form mirrors PySpark: range(n) == range(0, n).
        if end is None:
            end = start
            start = 0
        if numPartitions:
            return SparkConnectDataFrame(
                self._jsession.range(start, end, step, numPartitions))
        return SparkConnectDataFrame(self._jsession.range(start, end, step))

    @property
    def catalog(self):
        # Raw Java Catalog proxy.
        return self._jsession.catalog()

    @property
    def version(self):
        return self._jsession.version()

    @property
    def conf(self):
        # Raw Java RuntimeConfig proxy.
        return self._jsession.conf()

    def stop(self):
        # Intentional no-op: session lifecycle is owned by the Java-side
        # interpreter, not by notebook code.
        pass

    def __repr__(self):
        return "SparkConnectSession (via Py4j)"

    def __getattr__(self, name):
        # Fall through to the Java session for anything not wrapped above.
        return getattr(self._jsession, name)
+
+
def pip_install(*packages):
    """Install Python packages into the interpreter pod's environment.

    Usage:
        pip_install("requests")
        pip_install("requests", "pandas", "numpy==1.24.0")
        pip_install("requests>=2.28,<3.0")
    """
    import subprocess
    import importlib
    import site
    if not packages:
        print("Usage: pip_install('package1', 'package2', ...)")
        return
    cmd = [sys.executable, "-m", "pip", "install", "--quiet", *packages]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
        out, err = proc.stdout.strip(), proc.stderr.strip()
        if proc.returncode == 0:
            print("Successfully installed: %s" % ", ".join(packages))
            if out:
                print(out)
            # Refresh import machinery so the new packages are importable
            # in this interpreter session without a restart.
            importlib.invalidate_caches()
            for path in site.getsitepackages() + [site.getusersitepackages()]:
                if path not in sys.path:
                    sys.path.insert(0, path)
        else:
            print("pip install failed (exit code %d):" % proc.returncode)
            if err:
                print(err)
            if out:
                print(out)
    except subprocess.TimeoutExpired:
        print("pip install timed out after 300 seconds")
    except Exception as e:
        print("pip install error: %s" % str(e))
+
+
def display(obj, n=20):
    """Databricks-compatible display function.

    Renders a SparkConnectDataFrame through Zeppelin's %table output;
    any other object is simply printed.
    """
    if not isinstance(obj, SparkConnectDataFrame):
        print(obj)
        return
    obj.show(n)
+
+
# Module-level handles mirroring the names PySpark notebook users expect.
spark = SparkConnectSession(_jspark)
sqlContext = sqlc = spark  # legacy aliases for SQLContext-era notebooks
diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java
new file mode 100644
index 00000000000..a8fe2cc4c52
--- /dev/null
+++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/PySparkConnectInterpreterTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient;
+import org.apache.zeppelin.resource.LocalResourcePool;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Integration tests for PySparkConnectInterpreter.
+ * Requires a running Spark Connect server.
+ * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable.
+ */
+@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+")
+public class PySparkConnectInterpreterTest {
+
+ private static PySparkConnectInterpreter interpreter;
+ private static SparkConnectInterpreter sparkConnectInterpreter;
+ private static InterpreterGroup intpGroup;
+
+ @BeforeAll
+ public static void setUp() throws Exception {
+ String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE");
+ Properties p = new Properties();
+ p.setProperty("spark.remote", remote);
+ p.setProperty("spark.app.name", "ZeppelinPySparkConnectTest");
+ p.setProperty("zeppelin.spark.maxResult", "100");
+ p.setProperty("zeppelin.pyspark.connect.useIPython", "false");
+ p.setProperty("zeppelin.python", "python");
+
+ intpGroup = new InterpreterGroup();
+
+ // Create SparkConnectInterpreter first (required dependency)
+ sparkConnectInterpreter = new SparkConnectInterpreter(p);
+ sparkConnectInterpreter.setInterpreterGroup(intpGroup);
+ intpGroup.put("session_1", new LinkedList());
+ intpGroup.get("session_1").add(sparkConnectInterpreter);
+ sparkConnectInterpreter.open();
+
+ // Create PySparkConnectInterpreter
+ interpreter = new PySparkConnectInterpreter(p);
+ interpreter.setInterpreterGroup(intpGroup);
+ intpGroup.get("session_1").add(interpreter);
+ interpreter.open();
+ }
+
+ @AfterAll
+ public static void tearDown() throws InterpreterException {
+ if (interpreter != null) {
+ interpreter.close();
+ }
+ if (sparkConnectInterpreter != null) {
+ sparkConnectInterpreter.close();
+ }
+ }
+
+ private static InterpreterContext getInterpreterContext() {
+ return InterpreterContext.builder()
+ .setNoteId("noteId")
+ .setParagraphId("paragraphId")
+ .setParagraphTitle("title")
+ .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null))
+ .setResourcePool(new LocalResourcePool("id"))
+ .setInterpreterOut(new InterpreterOutput())
+ .setIntpEventClient(mock(RemoteInterpreterEventClient.class))
+ .build();
+ }
+
+ @Test
+ void testSparkSessionCreated() {
+ assertNotNull(interpreter.getSparkConnectInterpreter());
+ assertNotNull(interpreter.getSparkConnectInterpreter().getSparkSession());
+ }
+
+ @Test
+ void testSimpleQuery() throws InterpreterException, IOException {
+ InterpreterContext context = getInterpreterContext();
+ InterpreterResult result = interpreter.interpret(
+ "df = spark.sql(\"SELECT 1 AS id, 'hello' AS message\")\ndf.show()", context);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ String output = context.out.toInterpreterResultMessage().get(0).getData();
+ assertTrue(output.contains("id") || output.contains("message") || output.contains("hello"),
+ "Output should contain query results: " + output);
+ }
+
+ @Test
+ void testDataFrameVariable() throws InterpreterException, IOException {
+ InterpreterContext context = getInterpreterContext();
+ InterpreterResult result = interpreter.interpret(
+ "df = spark.sql(\"SELECT 1 AS id, 'test' AS name\")\nprint(type(df))", context);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ String output = context.out.toInterpreterResultMessage().get(0).getData();
+ assertTrue(output.contains("DataFrame") || output.contains("pyspark"),
+ "Output should indicate DataFrame type: " + output);
+ }
+
+ @Test
+ void testDeltaTableQuery() throws InterpreterException, IOException {
+ InterpreterContext context = getInterpreterContext();
+ // Test the exact query from user
+ InterpreterResult result = interpreter.interpret(
+ "df = spark.sql(\"select * from gold.delta_test\")", context);
+ // This might fail if table doesn't exist, but should not crash
+ // We check that interpreter handled it gracefully
+ assertTrue(result.code() == InterpreterResult.Code.SUCCESS
+ || result.code() == InterpreterResult.Code.ERROR,
+ "Should handle query execution (success or error): " + result.code());
+ }
+
+ @Test
+ void testSparkVariableAvailable() throws InterpreterException, IOException {
+ InterpreterContext context = getInterpreterContext();
+ InterpreterResult result = interpreter.interpret(
+ "print('Spark session:', type(spark))", context);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ String output = context.out.toInterpreterResultMessage().get(0).getData();
+ assertTrue(output.contains("SparkSession") || output.contains("spark"),
+ "Spark session should be available: " + output);
+ }
+}
diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java
new file mode 100644
index 00000000000..0b3cbc8213b
--- /dev/null
+++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectInterpreterTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient;
+import org.apache.zeppelin.resource.LocalResourcePool;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Integration tests for SparkConnectInterpreter.
+ * Requires a running Spark Connect server.
+ * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable.
+ */
+@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+")
+public class SparkConnectInterpreterTest {
+
+  private static SparkConnectInterpreter interpreter;
+  private static InterpreterGroup intpGroup;
+
+  /** Opens one shared interpreter session against the remote given by the env var. */
+  @BeforeAll
+  public static void setUp() throws Exception {
+    String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE");
+    Properties p = new Properties();
+    p.setProperty("spark.remote", remote);
+    p.setProperty("spark.app.name", "ZeppelinSparkConnectTest");
+    p.setProperty("zeppelin.spark.maxResult", "100");
+    p.setProperty("zeppelin.spark.sql.stacktrace", "true");
+
+    intpGroup = new InterpreterGroup();
+    interpreter = new SparkConnectInterpreter(p);
+    interpreter.setInterpreterGroup(intpGroup);
+    // Diamond operator: a raw LinkedList here triggers an unchecked warning.
+    intpGroup.put("session_1", new LinkedList<>());
+    intpGroup.get("session_1").add(interpreter);
+
+    interpreter.open();
+  }
+
+  @AfterAll
+  public static void tearDown() throws InterpreterException {
+    if (interpreter != null) {
+      interpreter.close();
+    }
+  }
+
+  /** Builds a fresh per-test context with an in-memory output buffer. */
+  private static InterpreterContext getInterpreterContext() {
+    return InterpreterContext.builder()
+        .setNoteId("noteId")
+        .setParagraphId("paragraphId")
+        .setParagraphTitle("title")
+        .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null))
+        .setResourcePool(new LocalResourcePool("id"))
+        .setInterpreterOut(new InterpreterOutput())
+        .setIntpEventClient(mock(RemoteInterpreterEventClient.class))
+        .build();
+  }
+
+  @Test
+  void testSparkSessionCreated() {
+    assertNotNull(interpreter.getSparkSession());
+  }
+
+  @Test
+  void testSimpleQuery() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = interpreter.interpret(
+        "SELECT 1 AS id, 'hello' AS message", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    assertTrue(output.contains("id"));
+    assertTrue(output.contains("message"));
+    assertTrue(output.contains("hello"));
+  }
+
+  @Test
+  void testMultipleStatements() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = interpreter.interpret(
+        "SELECT 1 AS a; SELECT 2 AS b", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+  }
+
+  @Test
+  void testInvalidSQL() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = interpreter.interpret(
+        "SELECT FROM WHERE INVALID", context);
+    assertEquals(InterpreterResult.Code.ERROR, result.code());
+  }
+
+  @Test
+  void testMaxResultLimit() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    context.getLocalProperties().put("limit", "5");
+    InterpreterResult result = interpreter.interpret(
+        "SELECT id FROM range(100)", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    String[] lines = output.split("\n");
+    // header + 5 data rows (one extra line tolerated for a trailing marker)
+    assertTrue(lines.length <= 7, "Expected at most 7 lines but got " + lines.length);
+  }
+
+  @Test
+  void testDDL() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = interpreter.interpret(
+        "CREATE OR REPLACE TEMP VIEW test_view AS SELECT 1 AS id, 'test' AS name", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+
+    // The temp view must be visible to a follow-up query in the same session.
+    context = getInterpreterContext();
+    result = interpreter.interpret("SELECT * FROM test_view", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    assertTrue(output.contains("test"));
+  }
+
+  @Test
+  void testFormType() {
+    assertEquals(Interpreter.FormType.SIMPLE, interpreter.getFormType());
+  }
+
+  @Test
+  void testProgress() throws InterpreterException {
+    InterpreterContext context = getInterpreterContext();
+    assertEquals(0, interpreter.getProgress(context));
+  }
+}
diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java
new file mode 100644
index 00000000000..acee3ed7a44
--- /dev/null
+++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectSqlInterpreterTest.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.remote.RemoteInterpreterEventClient;
+import org.apache.zeppelin.resource.LocalResourcePool;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Integration tests for SparkConnectSqlInterpreter.
+ * Requires a running Spark Connect server.
+ * Set SPARK_CONNECT_TEST_REMOTE env var (e.g. sc://localhost:15002) to enable.
+ */
+@EnabledIfEnvironmentVariable(named = "SPARK_CONNECT_TEST_REMOTE", matches = ".+")
+public class SparkConnectSqlInterpreterTest {
+
+  private static SparkConnectInterpreter connectInterpreter;
+  private static SparkConnectSqlInterpreter sqlInterpreter;
+  private static InterpreterGroup intpGroup;
+
+  /**
+   * Opens both interpreters in one group; the SQL interpreter resolves its
+   * session through the sibling connect interpreter in the same group.
+   */
+  @BeforeAll
+  public static void setUp() throws Exception {
+    String remote = System.getenv("SPARK_CONNECT_TEST_REMOTE");
+    Properties p = new Properties();
+    p.setProperty("spark.remote", remote);
+    p.setProperty("spark.app.name", "ZeppelinSparkConnectSqlTest");
+    p.setProperty("zeppelin.spark.maxResult", "100");
+    p.setProperty("zeppelin.spark.concurrentSQL", "true");
+    p.setProperty("zeppelin.spark.concurrentSQL.max", "10");
+    p.setProperty("zeppelin.spark.sql.stacktrace", "true");
+    p.setProperty("zeppelin.spark.sql.interpolation", "false");
+
+    intpGroup = new InterpreterGroup();
+    connectInterpreter = new SparkConnectInterpreter(p);
+    connectInterpreter.setInterpreterGroup(intpGroup);
+
+    sqlInterpreter = new SparkConnectSqlInterpreter(p);
+    sqlInterpreter.setInterpreterGroup(intpGroup);
+
+    // Diamond operator: a raw LinkedList here triggers an unchecked warning.
+    intpGroup.put("session_1", new LinkedList<>());
+    intpGroup.get("session_1").add(connectInterpreter);
+    intpGroup.get("session_1").add(sqlInterpreter);
+
+    connectInterpreter.open();
+    sqlInterpreter.open();
+  }
+
+  @AfterAll
+  public static void tearDown() throws InterpreterException {
+    // Close in reverse open order: SQL first, then the underlying connection.
+    if (sqlInterpreter != null) {
+      sqlInterpreter.close();
+    }
+    if (connectInterpreter != null) {
+      connectInterpreter.close();
+    }
+  }
+
+  /** Builds a fresh per-test context with an in-memory output buffer. */
+  private static InterpreterContext getInterpreterContext() {
+    return InterpreterContext.builder()
+        .setNoteId("noteId")
+        .setParagraphId("paragraphId")
+        .setParagraphTitle("title")
+        .setAngularObjectRegistry(new AngularObjectRegistry(intpGroup.getId(), null))
+        .setResourcePool(new LocalResourcePool("id"))
+        .setInterpreterOut(new InterpreterOutput())
+        .setIntpEventClient(mock(RemoteInterpreterEventClient.class))
+        .build();
+  }
+
+  @Test
+  void testSimpleQuery() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = sqlInterpreter.interpret(
+        "SELECT 1 AS id, 'hello' AS message", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    assertTrue(output.contains("id"));
+    assertTrue(output.contains("message"));
+    assertTrue(output.contains("hello"));
+  }
+
+  @Test
+  void testMultipleStatements() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = sqlInterpreter.interpret(
+        "SELECT 1 AS a; SELECT 2 AS b", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    // Each statement should produce its own result message.
+    assertEquals(2, context.out.toInterpreterResultMessage().size());
+  }
+
+  @Test
+  void testInvalidSQL() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = sqlInterpreter.interpret(
+        "SELECT FROM WHERE INVALID", context);
+    assertEquals(InterpreterResult.Code.ERROR, result.code());
+    // An error message must be surfaced to the user, not swallowed.
+    assertTrue(context.out.toInterpreterResultMessage().get(0).getData().length() > 0);
+  }
+
+  @Test
+  void testMaxResultLimit() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    context.getLocalProperties().put("limit", "3");
+    InterpreterResult result = sqlInterpreter.interpret(
+        "SELECT id FROM range(100)", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    String[] lines = output.split("\n");
+    // header + 3 data rows (one extra line tolerated for a trailing marker)
+    assertTrue(lines.length <= 5, "Expected at most 5 lines but got " + lines.length);
+  }
+
+  @Test
+  void testCreateAndQuery() throws InterpreterException, IOException {
+    InterpreterContext context = getInterpreterContext();
+    InterpreterResult result = sqlInterpreter.interpret(
+        "CREATE OR REPLACE TEMP VIEW sql_test AS SELECT 42 AS answer", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+
+    context = getInterpreterContext();
+    result = sqlInterpreter.interpret("SELECT * FROM sql_test", context);
+    assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+    String output = context.out.toInterpreterResultMessage().get(0).getData();
+    assertTrue(output.contains("42"));
+  }
+
+  @Test
+  void testFormType() {
+    assertEquals(Interpreter.FormType.SIMPLE, sqlInterpreter.getFormType());
+  }
+}
diff --git a/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java
new file mode 100644
index 00000000000..ebafd1a9bef
--- /dev/null
+++ b/spark-connect/src/test/java/org/apache/zeppelin/spark/SparkConnectUtilsTest.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Properties;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class SparkConnectUtilsTest {
+
+  /** Convenience factory: a Properties object carrying only spark.remote. */
+  private static Properties propsWithRemote(String remote) {
+    Properties props = new Properties();
+    props.setProperty("spark.remote", remote);
+    return props;
+  }
+
+  @Test
+  void testBuildConnectionStringDefault() {
+    assertEquals("sc://localhost:15002",
+        SparkConnectUtils.buildConnectionString(new Properties()));
+  }
+
+  @Test
+  void testBuildConnectionStringCustomRemote() {
+    Properties props = propsWithRemote("sc://spark-server.example.com:15002");
+    assertEquals("sc://spark-server.example.com:15002",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringWithToken() {
+    Properties props = propsWithRemote("sc://spark-server.example.com:15002");
+    props.setProperty("spark.connect.token", "my-secret-token");
+    assertEquals("sc://spark-server.example.com:15002/;token=my-secret-token",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringWithTokenAndExistingParams() {
+    Properties props = propsWithRemote("sc://host:15002/;use_ssl=true");
+    props.setProperty("spark.connect.token", "tok123");
+    assertEquals("sc://host:15002/;use_ssl=true;token=tok123",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringEmptyToken() {
+    Properties props = propsWithRemote("sc://host:15002");
+    props.setProperty("spark.connect.token", "");
+    // A blank token must not append an empty token parameter.
+    assertEquals("sc://host:15002", SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringWithUseSsl() {
+    Properties props = propsWithRemote("sc://ranking-cluster-m:8080");
+    props.setProperty("spark.connect.use_ssl", "true");
+    assertEquals("sc://ranking-cluster-m:8080/;use_ssl=true",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringWithSslAndToken() {
+    Properties props = propsWithRemote("sc://cluster:8080");
+    props.setProperty("spark.connect.use_ssl", "true");
+    props.setProperty("spark.connect.token", "abc123");
+    // Token is appended before use_ssl when both are configured.
+    assertEquals("sc://cluster:8080/;token=abc123;use_ssl=true",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringDataprocTunnel() {
+    Properties props = propsWithRemote("sc://localhost:15002");
+    assertEquals("sc://localhost:15002", SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringDataprocDirect() {
+    Properties props = propsWithRemote("sc://ranking-cluster-m:8080");
+    assertEquals("sc://ranking-cluster-m:8080",
+        SparkConnectUtils.buildConnectionString(props));
+  }
+
+  @Test
+  void testBuildConnectionStringWithUserName() {
+    Properties props = propsWithRemote("sc://cluster:8080");
+    assertEquals("sc://cluster:8080/;user_id=alice",
+        SparkConnectUtils.buildConnectionString(props, "alice"));
+  }
+
+  @Test
+  void testBuildConnectionStringWithUserNameAndToken() {
+    Properties props = propsWithRemote("sc://cluster:8080");
+    props.setProperty("spark.connect.token", "tok");
+    assertEquals("sc://cluster:8080/;token=tok;user_id=bob",
+        SparkConnectUtils.buildConnectionString(props, "bob"));
+  }
+
+  @Test
+  void testBuildConnectionStringUserIdAlreadyInUrl() {
+    // A user_id already present in the URL takes precedence over the argument.
+    Properties props = propsWithRemote("sc://cluster:8080/;user_id=preexisting");
+    assertEquals("sc://cluster:8080/;user_id=preexisting",
+        SparkConnectUtils.buildConnectionString(props, "alice"));
+  }
+
+  @Test
+  void testBuildConnectionStringNullUserName() {
+    Properties props = propsWithRemote("sc://cluster:8080");
+    assertEquals("sc://cluster:8080", SparkConnectUtils.buildConnectionString(props, null));
+  }
+
+  @Test
+  void testBuildConnectionStringBlankUserName() {
+    Properties props = propsWithRemote("sc://cluster:8080");
+    assertEquals("sc://cluster:8080", SparkConnectUtils.buildConnectionString(props, " "));
+  }
+
+  @Test
+  void testReplaceReservedChars() {
+    // Tabs and newlines collapse to spaces; null maps to the literal "null".
+    assertEquals("hello world", SparkConnectUtils.replaceReservedChars("hello\tworld"));
+    assertEquals("hello world", SparkConnectUtils.replaceReservedChars("hello\nworld"));
+    assertEquals("null", SparkConnectUtils.replaceReservedChars(null));
+    assertEquals("normal", SparkConnectUtils.replaceReservedChars("normal"));
+    assertEquals("a b c", SparkConnectUtils.replaceReservedChars("a\tb\nc"));
+  }
+}