Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 103 additions & 6 deletions jaydebeapiarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def reraise(tp, value, tb=None):

old_jpype = False

_jvm_started_pid = None

def _handle_sql_exception_jpype():
import jpype
SQLException = jpype.java.sql.SQLException
Expand All @@ -128,8 +130,90 @@ def _handle_sql_exception_jpype():

reraise(exc_type, exc_info[1], exc_info[2])

def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs):
def _dynamic_load_driver(jclassname, jars):
"""Load a JDBC driver from JARs after JVM start using the DriverShim pattern.

Java's DriverManager refuses to use drivers not loaded by the system
classloader. This function works around that restriction by creating a
URLClassLoader for the new JARs, instantiating the driver through it, and
registering a ``DriverShim`` proxy (loaded on the system classloader) with
DriverManager.

Args:
jclassname: Fully-qualified Java class name of the JDBC driver.
jars: List of JAR file paths to load the driver from.

Returns:
The URLClassLoader used to load the driver.
"""
import jpype

# Build URLClassLoader with the new JARs, parented to system classloader
urls = [jpype.java.io.File(j).toURI().toURL() for j in jars]
url_cl = jpype.java.net.URLClassLoader(
urls, jpype.java.lang.ClassLoader.getSystemClassLoader()
)

# Load driver class from custom classloader and instantiate
driver_cls = url_cl.loadClass(jclassname)
driver = driver_cls.getDeclaredConstructor().newInstance()

# Create a DriverShim proxy and register it with DriverManager.
# The shim is a Python-implemented java.sql.Driver that delegates
# every call to the real driver loaded via URLClassLoader.
# DriverManager accepts the shim because it is loaded by the system CL.

@jpype.JImplements("java.sql.Driver")
class DriverShim:
def __init__(self, _driver):
self._driver = _driver
@jpype.JOverride
def connect(self, u, info):
return self._driver.connect(u, info)
@jpype.JOverride
def acceptsURL(self, u):
return self._driver.acceptsURL(u)
@jpype.JOverride
def getPropertyInfo(self, u, info):
return self._driver.getPropertyInfo(u, info)
@jpype.JOverride
def getMajorVersion(self):
return self._driver.getMajorVersion()
@jpype.JOverride
def getMinorVersion(self):
return self._driver.getMinorVersion()
@jpype.JOverride
def jdbcCompliant(self):
return self._driver.jdbcCompliant()
@jpype.JOverride
def getParentLogger(self):
return self._driver.getParentLogger()

jpype.java.sql.DriverManager.registerDriver(DriverShim(driver))

# Update thread context classloader so the driver can find its own resources
jpype.java.lang.Thread.currentThread().setContextClassLoader(url_cl)

return url_cl


def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs, experimental=None):
import jpype
global _jvm_started_pid

_experimental = experimental or {}

if _jvm_started_pid is not None and _jvm_started_pid != os.getpid():
if not _experimental.get('dynamic_classpath'):
raise InterfaceError(
"Cannot use jaydebeapiarrow in a forked process. "
"The JVM was started in the parent process (PID %d) but this is "
"PID %d. JPype does not support fork after JVM start. "
"Move the connect() call after the fork, or use a "
"post-fork-spawn worker model (e.g. gunicorn --preload with "
"lazy connections)." % (_jvm_started_pid, os.getpid())
)

if not _is_jvm_started():
class_path = []
if jars:
Expand Down Expand Up @@ -177,17 +261,22 @@ def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs):
jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True,
convertStrings=True,
classpath=class_path)

_jvm_started_pid = os.getpid()

if not jpype.java.lang.Thread.isAttached():
jpype.attachThreadToJVM()
jpype.java.lang.Thread.currentThread().setContextClassLoader(jpype.java.lang.ClassLoader.getSystemClassLoader())
try:
import pyarrow.jvm
except ImportError as e:
raise RuntimeError(f"Failed to import pyarrow.jvm ({e}). Looks like JVM is not started. Thisis required for jaydebeapiarrow to work.")

# register driver for DriverManager
jpype.JClass(jclassname)
if _experimental.get('dynamic_classpath') and jars and _is_jvm_started():
_dynamic_load_driver(jclassname, jars)
else:
jpype.JClass(jclassname)

if isinstance(driver_args, dict):
Properties = jpype.java.util.Properties
info = Properties()
Expand Down Expand Up @@ -379,7 +468,7 @@ def TimestampFromTicks(ticks):
return Timestamp(*time.localtime(ticks)[:6])

# DB-API 2.0 Module Interface connect constructor
def connect(jclassname, url, driver_args=None, jars=None, libs=None):
def connect(jclassname, url, driver_args=None, jars=None, libs=None, experimental=None):
"""Open a connection to a database using a JDBC driver and return
a Connection instance.

Expand All @@ -395,6 +484,12 @@ def connect(jclassname, url, driver_args=None, jars=None, libs=None):
jars: Jar filename or sequence of filenames for the JDBC driver
libs: Dll/so filenames or sequence of dlls/sos used as shared
library by the JDBC driver
experimental: Optional dict of experimental feature flags.
Supported keys:
dynamic_classpath (bool): If True, allow loading JDBC drivers
from JARs after the JVM has already been started, using a
DriverShim proxy. This also bypasses the fork-after-JVM-start
guard, making it suitable for gunicorn --preload workers.
"""
if isinstance(driver_args, str):
driver_args = [ driver_args ]
Expand All @@ -410,7 +505,9 @@ def connect(jclassname, url, driver_args=None, jars=None, libs=None):
libs = [ libs ]
else:
libs = []
jconn = _jdbc_connect(jclassname, url, driver_args, jars, libs)
if experimental is None:
experimental = {}
jconn = _jdbc_connect(jclassname, url, driver_args, jars, libs, experimental=experimental)
return Connection(jconn, jclassname)

# DB-API 2.0 Connection Object
Expand Down
161 changes: 161 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,3 +2007,164 @@ def test_hsqldb_jar_path_with_spaces(self):
)
self.assertTrue(result.stdout.strip().startswith('OK'),
f'Connection failed: {result.stdout}\n{result.stderr}')


class ForkSafetyTest(unittest.TestCase):
"""Tests for fork-safety guard (legacy issue #232)."""

def test_fork_after_connect_raises_interface_error(self):
"""Simulating a fork by overwriting the PID tracker must raise
InterfaceError when attempting a new connection."""
import os
original_pid = jaydebeapiarrow._jvm_started_pid
try:
jaydebeapiarrow._jvm_started_pid = os.getpid() + 99999
with self.assertRaises(jaydebeapiarrow.InterfaceError) as ctx:
jaydebeapiarrow.connect('org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:.', ['SA', ''])
self.assertIn("forked process", str(ctx.exception))
finally:
jaydebeapiarrow._jvm_started_pid = original_pid

def test_pid_recorded_after_connect(self):
"""After connect(), _jvm_started_pid must equal the current PID."""
import os
c = jaydebeapiarrow.connect('org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:.', ['SA', ''])
try:
self.assertEqual(jaydebeapiarrow._jvm_started_pid, os.getpid())
finally:
c.close()


class DynamicClasspathIntegrationTest(unittest.TestCase):
"""Tests for experimental dynamic_classpath feature with real JDBC driver."""

def _find_hsqldb_jar(self):
jar_dir = os.path.join(_THIS_DIR, 'jars')
if not os.path.isdir(jar_dir):
self.skipTest('test/jars/ directory not found (run download_jdbc_drivers.sh)')
for f in os.listdir(jar_dir):
if 'hsqldb' in f.lower() and f.endswith('.jar'):
return os.path.join(jar_dir, f)
self.skipTest('HSQLDB JAR not found in test/jars/')

def _find_mock_jar(self):
for root, dirs, files in os.walk(_THIS_DIR):
for f in files:
if f.startswith('mockdriver') and f.endswith('.jar'):
return os.path.join(root, f)
self.skipTest('mockdriver JAR not found')

def _run_in_subprocess(self, code):
return subprocess.run(
[sys.executable, '-c', code],
capture_output=True, text=True, timeout=30,
cwd=os.path.dirname(_THIS_DIR)
)

def test_hsqldb_fails_without_dynamic_classpath(self):
"""Connecting to HSQLDB after JVM starts with only mock driver on classpath
should fail — the HSQLDB driver is not available."""
hsqldb_jar = self._find_hsqldb_jar()
mock_jar = self._find_mock_jar()

# Start JVM with CLASSPATH pointing only to mock JAR (no HSQLDB)
env = {**os.environ, 'CLASSPATH': mock_jar}
code = f'''
import jaydebeapiarrow

# Start JVM with only the mock driver available
conn1 = jaydebeapiarrow.connect(
'org.jaydebeapi.mockdriver.MockDriver',
'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Try to connect to HSQLDB without dynamic classpath — should fail
# because HSQLDB driver was never loaded
try:
conn2 = jaydebeapiarrow.connect(
'org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:.',
['SA', '']
)
conn2.close()
print('UNEXPECTED_SUCCESS')
except Exception as e:
print(f'EXPECTED_FAIL: {{type(e).__name__}}')
'''
result = subprocess.run(
[sys.executable, '-c', code],
capture_output=True, text=True, timeout=30,
cwd=os.path.dirname(_THIS_DIR),
env=env
)
self.assertTrue(result.stdout.strip().startswith('EXPECTED_FAIL'),
f'HSQLDB should fail without dynamic classpath.\n'
f'stdout: {result.stdout}\nstderr: {result.stderr}')

def test_dynamic_load_hsqldb_after_jvm_start(self):
"""Dynamically load HSQLDB driver after JVM is already running.
Starts JVM with only the mock driver, then loads HSQLDB from JAR."""
hsqldb_jar = self._find_hsqldb_jar()
mock_jar = self._find_mock_jar()

# Start JVM with CLASSPATH pointing only to mock JAR (no HSQLDB)
env = {**os.environ, 'CLASSPATH': mock_jar}
code = f'''
import jaydebeapiarrow

# Start JVM with only the mock driver on the classpath
conn1 = jaydebeapiarrow.connect(
'org.jaydebeapi.mockdriver.MockDriver',
'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Verify HSQLDB is NOT available yet
try:
conn_bad = jaydebeapiarrow.connect(
'org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:.',
['SA', '']
)
conn_bad.close()
print('HSQQLDB_AVAILABLE_WITHOUT_DYNAMIC')
except Exception:
print('HSQQLDB_NOT_AVAILABLE')

# Now dynamically load HSQLDB driver from JAR
conn2 = jaydebeapiarrow.connect(
'org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:.',
['SA', ''],
jars={repr(hsqldb_jar)},
experimental={{'dynamic_classpath': True}}
)
cursor = conn2.cursor()

# Verify it actually works — run real SQL
cursor.execute('CREATE TABLE test_dynamic (id INTEGER, name VARCHAR(50))')
cursor.execute("INSERT INTO test_dynamic VALUES (1, 'hello'), (2, 'world')")
cursor.execute('SELECT id, name FROM test_dynamic ORDER BY id')
rows = cursor.fetchall()
cursor.execute('DROP TABLE test_dynamic')
cursor.close()
conn2.close()

print(f'DYNAMIC_OK: {{rows}}')
'''
result = subprocess.run(
[sys.executable, '-c', code],
capture_output=True, text=True, timeout=30,
cwd=os.path.dirname(_THIS_DIR),
env=env
)
lines = result.stdout.strip().split('\n')
self.assertEqual(lines[0], 'HSQQLDB_NOT_AVAILABLE',
f'HSQLDB should not be available before dynamic load.\n'
f'stdout: {result.stdout}\nstderr: {result.stderr}')
self.assertEqual(lines[1], 'DYNAMIC_OK: [(1, \'hello\'), (2, \'world\')]',
f'Dynamic HSQLDB load failed or returned wrong data.\n'
f'stdout: {result.stdout}\nstderr: {result.stderr}')
Loading
Loading