diff --git a/jaydebeapiarrow/__init__.py b/jaydebeapiarrow/__init__.py
index 9b1f8053..4d9ccd9c 100644
--- a/jaydebeapiarrow/__init__.py
+++ b/jaydebeapiarrow/__init__.py
@@ -32,6 +32,7 @@
 from decimal import Decimal
 import glob
 import os
+import threading
 import time
 import sys
 import warnings
@@ -111,6 +112,15 @@ def reraise(tp, value, tb=None):
 
 old_jpype = False
 
+# Flag and lock to prevent race condition when multiple threads call
+# connect() simultaneously — without this, two threads can both see
+# isJVMStarted() return False and both attempt to start the JVM,
+# causing a crash. The lock is only held briefly to check/set the
+# flag; startJVM() runs _outside_ the lock to avoid potential
+# deadlocks if JPype spawns threads during initialisation.
+_jvm_startup_lock = threading.Lock()
+_jvm_starting = False
+
 def _handle_sql_exception_jpype():
     import jpype
     SQLException = jpype.java.sql.SQLException
@@ -130,54 +140,84 @@ def _handle_sql_exception_jpype():
 
 def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs):
     import jpype
-    if not _is_jvm_started():
-        class_path = []
-        if jars:
-            class_path.extend(jars)
-        class_path.extend(_get_classpath())
-        class_path.extend(_get_arrow_jar_paths())
-        class_path = list(set(class_path))
-
-        args = []
-
-        if libs:
-            # path to shared libraries
-            libs_path = os.path.pathsep.join(libs)
-            args.append('-Djava.library.path=%s' % libs_path)
-
-        # Known issue: some JDBC drivers (notably IBM Db2) use the JVM's
-        # default charset for string conversion. When the default is not
-        # UTF-8, non-ASCII characters (German umlauts, CJK, emoji) cause
-        # CharConversionException during result-set traversal. Users who
-        # encounter this should pass jvm_args=['-Dfile.encoding=UTF-8']
-        # when calling connect().
-        # TODO: document this encoding requirement in user-facing docs
-        # and consider exposing a dedicated encoding parameter in connect().
-
-        # Add-opens for Apache Arrow on Java 9+
-        args.append('--add-opens=java.base/java.nio=ALL-UNNAMED')
-
-        # jvm_path = ('/usr/lib/jvm/java-6-openjdk'
-        #             '/jre/lib/i386/client/libjvm.so')
-        jvm_path = jpype.getDefaultJVMPath()
-        global old_jpype
-        if hasattr(jpype, '__version__'):
-            try:
-                ver_match = re.match(r'\d+\.\d+', jpype.__version__)
-                if ver_match:
-                    jpype_ver = float(ver_match.group(0))
-                    if jpype_ver < 0.7:
-                        old_jpype = True
-            except ValueError:
-                pass
-        if old_jpype:
-            jpype.startJVM(jvm_path, *args,
-                           classpath=class_path)
+    global _jvm_starting
+
+    # Brief lock: decide who starts the JVM (if needed).
+    with _jvm_startup_lock:
+        if _is_jvm_started():
+            should_start = False
+        elif _jvm_starting:
+            # Another thread is already starting the JVM; wait for it.
+            should_start = False
         else:
-            jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True,
-                           convertStrings=True,
-                           classpath=class_path)
-
+            _jvm_starting = True
+            should_start = True
+
+    if should_start:
+        try:
+            class_path = []
+            if jars:
+                class_path.extend(jars)
+            class_path.extend(_get_classpath())
+            class_path.extend(_get_arrow_jar_paths())
+            class_path = list(set(class_path))
+
+            args = []
+
+            if libs:
+                # path to shared libraries
+                libs_path = os.path.pathsep.join(libs)
+                args.append('-Djava.library.path=%s' % libs_path)
+
+            # Known issue: some JDBC drivers (notably IBM Db2) use the JVM's
+            # default charset for string conversion. When the default is not
+            # UTF-8, non-ASCII characters (German umlauts, CJK, emoji) cause
+            # CharConversionException during result-set traversal. Users who
+            # encounter this should pass jvm_args=['-Dfile.encoding=UTF-8']
+            # when calling connect().
+            # TODO: document this encoding requirement in user-facing docs
+            # and consider exposing a dedicated encoding parameter in connect().
+
+            # Add-opens for Apache Arrow on Java 9+
+            args.append('--add-opens=java.base/java.nio=ALL-UNNAMED')
+
+            # jvm_path = ('/usr/lib/jvm/java-6-openjdk'
+            #             '/jre/lib/i386/client/libjvm.so')
+            jvm_path = jpype.getDefaultJVMPath()
+            global old_jpype
+            if hasattr(jpype, '__version__'):
+                try:
+                    ver_match = re.match(r'\d+\.\d+', jpype.__version__)
+                    if ver_match:
+                        jpype_ver = float(ver_match.group(0))
+                        if jpype_ver < 0.7:
+                            old_jpype = True
+                except ValueError:
+                    pass
+            if old_jpype:
+                jpype.startJVM(jvm_path, *args,
+                               classpath=class_path)
+            else:
+                jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True,
+                               convertStrings=True,
+                               classpath=class_path)
+        finally:
+            with _jvm_startup_lock:
+                _jvm_starting = False
+    elif not _is_jvm_started():
+        # Another thread is starting the JVM; poll until it finishes.
+        deadline = time.monotonic() + 120
+        while _jvm_starting and not _is_jvm_started():
+            if time.monotonic() > deadline:
+                raise RuntimeError("Timed out waiting for JVM to start")
+            time.sleep(0.05)
+        if not _is_jvm_started():
+            # The starting thread gave up without bringing the JVM up.
+            # Fail with a clear error instead of falling through to the
+            # JPype calls below, which require a running JVM; the next
+            # connect() call will attempt the startup again.
+            raise RuntimeError("JVM startup failed in another thread")
+
     if not jpype.java.lang.Thread.isAttached():
         jpype.attachThreadToJVM()
     jpype.java.lang.Thread.currentThread().setContextClassLoader(jpype.java.lang.ClassLoader.getSystemClassLoader())
diff --git a/test/test_integration.py b/test/test_integration.py
index c9242f0d..79861bfc 100644
--- a/test/test_integration.py
+++ b/test/test_integration.py
@@ -32,6 +32,7 @@
 import sys
 import tempfile
 import threading
+from functools import partial
 import unittest
 
 
@@ -2007,3 +2008,39 @@ def test_hsqldb_jar_path_with_spaces(self):
             )
             self.assertTrue(result.stdout.strip().startswith('OK'),
                             f'Connection failed: {result.stdout}\n{result.stderr}')
+
+
+class ParallelConnectTest(unittest.TestCase):
+    """Test that parallel connect() calls are thread-safe (issue #60)."""
+
+    def test_parallel_connects_with_hsqldb(self):
+        """Multiple threads connecting simultaneously should not crash."""
+        errors = []
+
+        def connect_and_query(idx):
+            import jpype
+            try:
+                conn = jaydebeapiarrow.connect(
+                    'org.hsqldb.jdbcDriver',
+                    'jdbc:hsqldb:mem:parallel%d' % idx,
+                    ['SA', ''])
+                cursor = conn.cursor()
+                cursor.execute("SELECT 1 FROM (VALUES(0))")
+                rows = cursor.fetchall()
+                conn.close()
+            except Exception as e:
+                errors.append(e)
+            finally:
+                if jpype.isThreadAttachedToJVM():
+                    jpype.detachThreadFromJVM()
+
+        threads = []
+        for i in range(5):
+            t = threading.Thread(target=partial(connect_and_query, i))
+            threads.append(t)
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        self.assertEqual(errors, [], f"Thread errors: {errors}")
diff --git a/test/test_mock.py b/test/test_mock.py
index c45843d4..920533fb 100644
--- a/test/test_mock.py
+++ b/test/test_mock.py
@@ -25,6 +25,8 @@
 import subprocess
 import sys
 import tempfile
+import threading
+from functools import partial
 
 try:
     import unittest2 as unittest
@@ -1327,3 +1329,43 @@ def test_jar_path_with_special_chars(self):
             shutil.copy2(mock_jar, dest)
             stdout, stderr = self._run_connect_in_subprocess(dest)
             self.assertEqual(stdout, 'OK', f'Connection failed: {stderr}')
+
+
+class ParallelConnectTest(unittest.TestCase):
+    """Test that parallel connect() calls are thread-safe (issue #60)."""
+
+    def test_parallel_connects_after_jvm_started(self):
+        """Multiple threads connecting simultaneously should not crash."""
+        errors = []
+
+        def connect_thread(idx):
+            import jpype
+            try:
+                conn = jaydebeapiarrow.connect(
+                    'org.jaydebeapi.mockdriver.MockDriver',
+                    'jdbc:jaydebeapi://dummyurl%d' % idx)
+                # Verify the connection is usable
+                self.assertIsNotNone(conn)
+                conn.close()
+            except Exception as e:
+                errors.append(e)
+            finally:
+                if jpype.isThreadAttachedToJVM():
+                    jpype.detachThreadFromJVM()
+
+        threads = []
+        for i in range(5):
+            t = threading.Thread(target=partial(connect_thread, i))
+            threads.append(t)
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        self.assertEqual(errors, [], f"Thread errors: {errors}")
+
+    def test_jvm_startup_lock_exists(self):
+        """The _jvm_startup_lock should be a threading.Lock."""
+        self.assertTrue(hasattr(jaydebeapiarrow, '_jvm_startup_lock'))
+        self.assertIsInstance(jaydebeapiarrow._jvm_startup_lock, type(threading.Lock()))
+        self.assertTrue(hasattr(jaydebeapiarrow, '_jvm_starting'))