def _dynamic_load_driver(jclassname, jars):
    """Load a JDBC driver from JARs after JVM start and register a shim for it.

    A ``URLClassLoader`` parented to the system classloader is created for
    ``jars``; the driver class is loaded and instantiated through it, and a
    Python-implemented ``java.sql.Driver`` proxy (``DriverShim``) that
    forwards every call to that instance is registered with
    ``DriverManager``.

    The classloader is cached per ``(jclassname, jars)`` on the function
    object. Without the cache every call would create another classloader,
    another driver instance and another registered shim; nothing here ever
    deregisters a shim, so they would accumulate for the life of the JVM.

    NOTE(review): this relies on ``DriverManager`` accepting the JPype proxy
    as a driver visible to the caller; how JPype defines proxy classes is not
    visible from this function -- confirm against the JPype documentation.
    The cache is not synchronized; concurrent first calls for the same key
    may still register two shims.

    Args:
        jclassname: Fully-qualified Java class name of the JDBC driver.
        jars: List of JAR file paths to load the driver from.

    Returns:
        The URLClassLoader used to load the driver (the cached one on
        repeated calls with the same arguments).
    """
    import jpype

    # Stored on the function object so no extra module-level state is needed.
    loaded = _dynamic_load_driver.__dict__.setdefault("_loaded", {})
    cache_key = (jclassname, tuple(jars))
    url_cl = loaded.get(cache_key)

    if url_cl is None:
        # Build URLClassLoader with the new JARs, parented to system classloader
        urls = [jpype.java.io.File(j).toURI().toURL() for j in jars]
        url_cl = jpype.java.net.URLClassLoader(
            urls, jpype.java.lang.ClassLoader.getSystemClassLoader()
        )

        # Load driver class from custom classloader and instantiate
        driver_cls = url_cl.loadClass(jclassname)
        driver = driver_cls.getDeclaredConstructor().newInstance()

        # The shim is a Python-implemented java.sql.Driver that delegates
        # every call to the real driver loaded via URLClassLoader.
        @jpype.JImplements("java.sql.Driver")
        class DriverShim:
            def __init__(self, _driver):
                self._driver = _driver

            @jpype.JOverride
            def connect(self, u, info):
                return self._driver.connect(u, info)

            @jpype.JOverride
            def acceptsURL(self, u):
                return self._driver.acceptsURL(u)

            @jpype.JOverride
            def getPropertyInfo(self, u, info):
                return self._driver.getPropertyInfo(u, info)

            @jpype.JOverride
            def getMajorVersion(self):
                return self._driver.getMajorVersion()

            @jpype.JOverride
            def getMinorVersion(self):
                return self._driver.getMinorVersion()

            @jpype.JOverride
            def jdbcCompliant(self):
                return self._driver.jdbcCompliant()

            @jpype.JOverride
            def getParentLogger(self):
                return self._driver.getParentLogger()

        jpype.java.sql.DriverManager.registerDriver(DriverShim(driver))
        # Only remember the loader once registration has succeeded, so a
        # failed attempt is retried from scratch on the next call.
        loaded[cache_key] = url_cl

    # The context classloader is per-thread, so set it on every call (the
    # caller may be a different thread than the one that populated the cache)
    # so the driver can find its own resources.
    jpype.java.lang.Thread.currentThread().setContextClassLoader(url_cl)

    return url_cl
% (_jvm_started_pid, os.getpid()) + ) + if not _is_jvm_started(): class_path = [] if jars: @@ -177,7 +261,8 @@ def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs): jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True, convertStrings=True, classpath=class_path) - + _jvm_started_pid = os.getpid() + if not jpype.java.lang.Thread.isAttached(): jpype.attachThreadToJVM() jpype.java.lang.Thread.currentThread().setContextClassLoader(jpype.java.lang.ClassLoader.getSystemClassLoader()) @@ -185,9 +270,13 @@ def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs): import pyarrow.jvm except ImportError as e: raise RuntimeError(f"Failed to import pyarrow.jvm ({e}). Looks like JVM is not started. Thisis required for jaydebeapiarrow to work.") - + # register driver for DriverManager - jpype.JClass(jclassname) + if _experimental.get('dynamic_classpath') and jars and _is_jvm_started(): + _dynamic_load_driver(jclassname, jars) + else: + jpype.JClass(jclassname) + if isinstance(driver_args, dict): Properties = jpype.java.util.Properties info = Properties() @@ -379,7 +468,7 @@ def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) # DB-API 2.0 Module Interface connect constructor -def connect(jclassname, url, driver_args=None, jars=None, libs=None): +def connect(jclassname, url, driver_args=None, jars=None, libs=None, experimental=None): """Open a connection to a database using a JDBC driver and return a Connection instance. @@ -395,6 +484,12 @@ def connect(jclassname, url, driver_args=None, jars=None, libs=None): jars: Jar filename or sequence of filenames for the JDBC driver libs: Dll/so filenames or sequence of dlls/sos used as shared library by the JDBC driver + experimental: Optional dict of experimental feature flags. + Supported keys: + dynamic_classpath (bool): If True, allow loading JDBC drivers + from JARs after the JVM has already been started, using a + DriverShim proxy. 
class ForkSafetyTest(unittest.TestCase):
    """Tests for fork-safety guard (legacy issue #232)."""

    def test_fork_after_connect_raises_interface_error(self):
        """A PID tracker that differs from the current PID (a simulated fork)
        must make connect() raise InterfaceError."""
        import os
        saved_pid = jaydebeapiarrow._jvm_started_pid
        # Pretend the JVM was started by a different process.
        jaydebeapiarrow._jvm_started_pid = os.getpid() + 99999
        try:
            with self.assertRaises(jaydebeapiarrow.InterfaceError) as raised:
                jaydebeapiarrow.connect(
                    'org.hsqldb.jdbcDriver', 'jdbc:hsqldb:mem:.', ['SA', '']
                )
        finally:
            # Always restore the tracker so later tests see the real value.
            jaydebeapiarrow._jvm_started_pid = saved_pid
        self.assertIn("forked process", str(raised.exception))

    def test_pid_recorded_after_connect(self):
        """Once connect() has returned, the module's PID tracker must hold
        this process's PID."""
        import os
        conn = jaydebeapiarrow.connect(
            'org.hsqldb.jdbcDriver', 'jdbc:hsqldb:mem:.', ['SA', '']
        )
        try:
            recorded = jaydebeapiarrow._jvm_started_pid
            self.assertEqual(recorded, os.getpid())
        finally:
            conn.close()
class DynamicClasspathIntegrationTest(unittest.TestCase):
    """Tests for experimental dynamic_classpath feature with real JDBC driver.

    Each scenario runs in a fresh Python interpreter whose CLASSPATH holds
    only the mock driver JAR, so HSQLDB can only become usable through the
    dynamic-loading path under test.
    """

    def _find_hsqldb_jar(self):
        """Return the path of an HSQLDB JAR in test/jars/; skip if there is none."""
        jar_dir = os.path.join(_THIS_DIR, 'jars')
        if not os.path.isdir(jar_dir):
            self.skipTest('test/jars/ directory not found (run download_jdbc_drivers.sh)')
        # Sorted so the same JAR is chosen on every run when several match;
        # os.listdir() returns entries in arbitrary order.
        for f in sorted(os.listdir(jar_dir)):
            if 'hsqldb' in f.lower() and f.endswith('.jar'):
                return os.path.join(jar_dir, f)
        self.skipTest('HSQLDB JAR not found in test/jars/')

    def _find_mock_jar(self):
        """Return the first mockdriver*.jar found below the test directory; skip if none."""
        for root, dirs, files in os.walk(_THIS_DIR):
            for f in files:
                if f.startswith('mockdriver') and f.endswith('.jar'):
                    return os.path.join(root, f)
        self.skipTest('mockdriver JAR not found')

    def _run_in_subprocess(self, code, env=None):
        """Run ``code`` with ``python -c`` from the repository root.

        Args:
            code: Python source to execute in the child interpreter.
            env: Optional environment mapping for the child; ``None`` inherits
                the current environment.

        Returns:
            The ``subprocess.CompletedProcess`` with captured text output.
        """
        return subprocess.run(
            [sys.executable, '-c', code],
            capture_output=True, text=True, timeout=30,
            cwd=os.path.dirname(_THIS_DIR),
            env=env
        )

    def test_hsqldb_fails_without_dynamic_classpath(self):
        """Connecting to HSQLDB after JVM starts with only mock driver on classpath
        should fail — the HSQLDB driver is not available."""
        # Called only for its skip behaviour: the scenario itself never loads
        # the HSQLDB JAR.
        self._find_hsqldb_jar()
        mock_jar = self._find_mock_jar()

        # Start JVM with CLASSPATH pointing only to mock JAR (no HSQLDB)
        env = {**os.environ, 'CLASSPATH': mock_jar}
        code = f'''
import jaydebeapiarrow

# Start JVM with only the mock driver available
conn1 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Try to connect to HSQLDB without dynamic classpath — should fail
# because HSQLDB driver was never loaded
try:
    conn2 = jaydebeapiarrow.connect(
        'org.hsqldb.jdbcDriver',
        'jdbc:hsqldb:mem:.',
        ['SA', '']
    )
    conn2.close()
    print('UNEXPECTED_SUCCESS')
except Exception as e:
    print(f'EXPECTED_FAIL: {{type(e).__name__}}')
'''
        result = self._run_in_subprocess(code, env=env)
        self.assertTrue(result.stdout.strip().startswith('EXPECTED_FAIL'),
                        f'HSQLDB should fail without dynamic classpath.\n'
                        f'stdout: {result.stdout}\nstderr: {result.stderr}')

    def test_dynamic_load_hsqldb_after_jvm_start(self):
        """Dynamically load HSQLDB driver after JVM is already running.
        Starts JVM with only the mock driver, then loads HSQLDB from JAR."""
        hsqldb_jar = self._find_hsqldb_jar()
        mock_jar = self._find_mock_jar()

        # Start JVM with CLASSPATH pointing only to mock JAR (no HSQLDB)
        env = {**os.environ, 'CLASSPATH': mock_jar}
        code = f'''
import jaydebeapiarrow

# Start JVM with only the mock driver on the classpath
conn1 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Verify HSQLDB is NOT available yet
try:
    conn_bad = jaydebeapiarrow.connect(
        'org.hsqldb.jdbcDriver',
        'jdbc:hsqldb:mem:.',
        ['SA', '']
    )
    conn_bad.close()
    print('HSQQLDB_AVAILABLE_WITHOUT_DYNAMIC')
except Exception:
    print('HSQQLDB_NOT_AVAILABLE')

# Now dynamically load HSQLDB driver from JAR
conn2 = jaydebeapiarrow.connect(
    'org.hsqldb.jdbcDriver',
    'jdbc:hsqldb:mem:.',
    ['SA', ''],
    jars={repr(hsqldb_jar)},
    experimental={{'dynamic_classpath': True}}
)
cursor = conn2.cursor()

# Verify it actually works — run real SQL
cursor.execute('CREATE TABLE test_dynamic (id INTEGER, name VARCHAR(50))')
cursor.execute("INSERT INTO test_dynamic VALUES (1, 'hello'), (2, 'world')")
cursor.execute('SELECT id, name FROM test_dynamic ORDER BY id')
rows = cursor.fetchall()
cursor.execute('DROP TABLE test_dynamic')
cursor.close()
conn2.close()

print(f'DYNAMIC_OK: {{rows}}')
'''
        result = self._run_in_subprocess(code, env=env)
        lines = result.stdout.strip().splitlines()
        # Slices instead of indexing: if the child died early there may be
        # fewer than two lines, and the failure should show the captured
        # output rather than an IndexError.
        self.assertEqual(lines[:1], ['HSQQLDB_NOT_AVAILABLE'],
                         f'HSQLDB should not be available before dynamic load.\n'
                         f'stdout: {result.stdout}\nstderr: {result.stderr}')
        self.assertEqual(lines[1:2], ['DYNAMIC_OK: [(1, \'hello\'), (2, \'world\')]'],
                         f'Dynamic HSQLDB load failed or returned wrong data.\n'
                         f'stdout: {result.stdout}\nstderr: {result.stderr}')
class DynamicClasspathTest(unittest.TestCase):
    """Tests for experimental dynamic_classpath feature.

    These tests run in subprocesses because the JVM can only be started once
    per process, and dynamic loading needs a JVM that is already running.
    """

    def _find_mock_jar(self):
        """Return the first mockdriver*.jar found below this file's directory.

        Fails the test when no such JAR exists.
        """
        for root, dirs, files in os.walk(os.path.dirname(__file__)):
            for f in files:
                if f.startswith('mockdriver') and f.endswith('.jar'):
                    return os.path.join(root, f)
        self.fail('mockdriver JAR not found')

    def _run_in_subprocess(self, code):
        """Run code in a fresh subprocess and return stdout, stderr (both stripped)."""
        result = subprocess.run(
            [sys.executable, '-c', code],
            capture_output=True, text=True, timeout=30,
            cwd=os.path.dirname(os.path.dirname(__file__))
        )
        return result.stdout.strip(), result.stderr.strip()

    def test_dynamic_load_after_jvm_start(self):
        """Connect with a driver JAR after JVM is already running (dynamic_classpath)."""
        mock_jar = self._find_mock_jar()
        code = f'''
import jaydebeapiarrow

# First connection starts the JVM normally (no jars needed — mock driver
# is found via CLASSPATH in test harness)
conn1 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Second connection uses dynamic classpath to load the driver from JAR
conn2 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl',
    jars={repr(mock_jar)},
    experimental={{'dynamic_classpath': True}}
)
conn2.close()
print('OK')
'''
        stdout, stderr = self._run_in_subprocess(code)
        self.assertEqual(stdout, 'OK', f'Dynamic load failed: {stderr}')

    def test_dynamic_load_without_flag_raises_error(self):
        """Without the dynamic_classpath flag, passing JARs after JVM start in
        the same (un-forked) process must not trip the fork guard.

        The method name is historical: the guard compares the recorded PID
        with the current one, and the child script never changes PID, so the
        second connect is expected to succeed and print NO_ERROR.
        """
        mock_jar = self._find_mock_jar()
        code = f'''
import jaydebeapiarrow

# Start JVM with first connection
conn1 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Try connecting with explicit jars after JVM start — no experimental flag
try:
    conn2 = jaydebeapiarrow.connect(
        'org.jaydebeapi.mockdriver.MockDriver',
        'jdbc:jaydebeapi://dummyurl',
        jars={repr(mock_jar)}
    )
    conn2.close()
    print('NO_ERROR')
except jaydebeapiarrow.InterfaceError as e:
    if 'forked process' in str(e):
        print('FORK_ERROR')
    else:
        print(f'OTHER_INTERFACE_ERROR: {{e}}')
except Exception as e:
    print(f'OTHER_ERROR: {{type(e).__name__}}: {{e}}')
'''
        stdout, stderr = self._run_in_subprocess(code)
        # The previous membership check listed values the script can never
        # print verbatim ('OK', bare 'OTHER_INTERFACE_ERROR') and accepted
        # FORK_ERROR, which would mean the guard fired without a fork.
        # NO_ERROR is the only outcome that reflects correct behaviour.
        self.assertEqual(stdout, 'NO_ERROR',
                         f'Unexpected output: {stdout}\nstderr: {stderr}')

    def test_dynamic_load_bypasses_fork_guard(self):
        """dynamic_classpath flag bypasses the fork-after-JVM-start guard."""
        mock_jar = self._find_mock_jar()
        code = f'''
import jaydebeapiarrow, os

# Start JVM
conn1 = jaydebeapiarrow.connect(
    'org.jaydebeapi.mockdriver.MockDriver',
    'jdbc:jaydebeapi://dummyurl'
)
conn1.close()

# Simulate fork: change _jvm_started_pid to a different PID
jaydebeapiarrow._jvm_started_pid = os.getpid() + 99999

# Without flag — should raise
try:
    conn2 = jaydebeapiarrow.connect(
        'org.jaydebeapi.mockdriver.MockDriver',
        'jdbc:jaydebeapi://dummyurl',
        jars={repr(mock_jar)}
    )
    print('NO_ERROR')
except jaydebeapiarrow.InterfaceError as e:
    print('FORK_ERROR')

# With flag — should succeed
try:
    conn3 = jaydebeapiarrow.connect(
        'org.jaydebeapi.mockdriver.MockDriver',
        'jdbc:jaydebeapi://dummyurl',
        jars={repr(mock_jar)},
        experimental={{'dynamic_classpath': True}}
    )
    conn3.close()
    print('DYNAMIC_OK')
except Exception as e:
    print(f'DYNAMIC_FAIL: {{type(e).__name__}}: {{e}}')
'''
        stdout, stderr = self._run_in_subprocess(code)
        lines = stdout.splitlines()
        # Slices instead of indexing so a child that printed fewer than two
        # lines produces the diagnostic message, not an IndexError.
        self.assertEqual(lines[:1], ['FORK_ERROR'],
                         f'Expected fork error without flag, got: {stdout}\nstderr: {stderr}')
        self.assertEqual(lines[1:2], ['DYNAMIC_OK'],
                         f'Dynamic load should bypass fork guard, got: {stdout}\nstderr: {stderr}')