Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 86 additions & 47 deletions jaydebeapiarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from decimal import Decimal
import glob
import os
import threading
import time
import sys
import warnings
Expand Down Expand Up @@ -111,6 +112,15 @@ def reraise(tp, value, tb=None):

old_jpype = False

# Flag and lock to prevent race condition when multiple threads call
# connect() simultaneously — without this, two threads can both see
# isJVMStarted() return False and both attempt to start the JVM,
# causing a crash. The lock is only held briefly to check/set the
# flag; startJVM() runs _outside_ the lock to avoid potential
# deadlocks if JPype spawns threads during initialisation.
_jvm_startup_lock = threading.Lock()
_jvm_starting = False

def _handle_sql_exception_jpype():
import jpype
SQLException = jpype.java.sql.SQLException
Expand All @@ -130,54 +140,83 @@ def _handle_sql_exception_jpype():

def _jdbc_connect_jpype(jclassname, url, driver_args, jars, libs):
import jpype
if not _is_jvm_started():
class_path = []
if jars:
class_path.extend(jars)
class_path.extend(_get_classpath())
class_path.extend(_get_arrow_jar_paths())
class_path = list(set(class_path))

args = []

if libs:
# path to shared libraries
libs_path = os.path.pathsep.join(libs)
args.append('-Djava.library.path=%s' % libs_path)

# Known issue: some JDBC drivers (notably IBM Db2) use the JVM's
# default charset for string conversion. When the default is not
# UTF-8, non-ASCII characters (German umlauts, CJK, emoji) cause
# CharConversionException during result-set traversal. Users who
# encounter this should pass jvm_args=['-Dfile.encoding=UTF-8']
# when calling connect().
# TODO: document this encoding requirement in user-facing docs
# and consider exposing a dedicated encoding parameter in connect().

# Add-opens for Apache Arrow on Java 9+
args.append('--add-opens=java.base/java.nio=ALL-UNNAMED')

# jvm_path = ('/usr/lib/jvm/java-6-openjdk'
# '/jre/lib/i386/client/libjvm.so')
jvm_path = jpype.getDefaultJVMPath()
global old_jpype
if hasattr(jpype, '__version__'):
try:
ver_match = re.match(r'\d+\.\d+', jpype.__version__)
if ver_match:
jpype_ver = float(ver_match.group(0))
if jpype_ver < 0.7:
old_jpype = True
except ValueError:
pass
if old_jpype:
jpype.startJVM(jvm_path, *args,
classpath=class_path)
global _jvm_starting

# Brief lock: decide who starts the JVM (if needed).
with _jvm_startup_lock:
if _is_jvm_started():
should_start = False
elif _jvm_starting:
# Another thread is already starting the JVM; wait for it.
should_start = False
else:
jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True,
convertStrings=True,
classpath=class_path)

_jvm_starting = True
should_start = True

if should_start:
try:
class_path = []
if jars:
class_path.extend(jars)
class_path.extend(_get_classpath())
class_path.extend(_get_arrow_jar_paths())
class_path = list(set(class_path))

args = []

if libs:
# path to shared libraries
libs_path = os.path.pathsep.join(libs)
args.append('-Djava.library.path=%s' % libs_path)

# Known issue: some JDBC drivers (notably IBM Db2) use the JVM's
# default charset for string conversion. When the default is not
# UTF-8, non-ASCII characters (German umlauts, CJK, emoji) cause
# CharConversionException during result-set traversal. Users who
# encounter this should pass jvm_args=['-Dfile.encoding=UTF-8']
# when calling connect().
# TODO: document this encoding requirement in user-facing docs
# and consider exposing a dedicated encoding parameter in connect().

# Add-opens for Apache Arrow on Java 9+
args.append('--add-opens=java.base/java.nio=ALL-UNNAMED')

# jvm_path = ('/usr/lib/jvm/java-6-openjdk'
# '/jre/lib/i386/client/libjvm.so')
jvm_path = jpype.getDefaultJVMPath()
global old_jpype
if hasattr(jpype, '__version__'):
try:
ver_match = re.match(r'\d+\.\d+', jpype.__version__)
if ver_match:
jpype_ver = float(ver_match.group(0))
if jpype_ver < 0.7:
old_jpype = True
except ValueError:
pass
if old_jpype:
jpype.startJVM(jvm_path, *args,
classpath=class_path)
else:
jpype.startJVM(jvm_path, *args, ignoreUnrecognized=True,
convertStrings=True,
classpath=class_path)
finally:
with _jvm_startup_lock:
_jvm_starting = False
elif not _is_jvm_started():
# Another thread is starting the JVM; spin-wait until ready.
waited = 0
while not _is_jvm_started():
if not _jvm_starting:
# Startup thread failed; bail out so the caller sees the
# original exception (or retries on the next connect()).
break
time.sleep(0.05)
waited += 0.05
if waited > 120:
raise RuntimeError("Timed out waiting for JVM to start")

if not jpype.java.lang.Thread.isAttached():
jpype.attachThreadToJVM()
jpype.java.lang.Thread.currentThread().setContextClassLoader(jpype.java.lang.ClassLoader.getSystemClassLoader())
Expand Down
37 changes: 37 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import sys
import tempfile
import threading
from functools import partial

import unittest

Expand Down Expand Up @@ -2007,3 +2008,39 @@ def test_hsqldb_jar_path_with_spaces(self):
)
self.assertTrue(result.stdout.strip().startswith('OK'),
f'Connection failed: {result.stdout}\n{result.stderr}')


class ParallelConnectTest(unittest.TestCase):
"""Test that parallel connect() calls are thread-safe (issue #60)."""

def test_parallel_connects_with_hsqldb(self):
"""Multiple threads connecting simultaneously should not crash."""
errors = []

def connect_and_query(idx):
import jpype
try:
conn = jaydebeapiarrow.connect(
'org.hsqldb.jdbcDriver',
'jdbc:hsqldb:mem:parallel%d' % idx,
['SA', ''])
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM (VALUES(0))")
rows = cursor.fetchall()
conn.close()
except Exception as e:
errors.append(e)
finally:
if jpype.isThreadAttachedToJVM():
jpype.detachThreadFromJVM()

threads = []
for i in range(5):
t = threading.Thread(target=partial(connect_and_query, i))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()

self.assertEqual(errors, [], f"Thread errors: {errors}")
42 changes: 42 additions & 0 deletions test/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import subprocess
import sys
import tempfile
import threading
from functools import partial

try:
import unittest2 as unittest
Expand Down Expand Up @@ -1327,3 +1329,43 @@ def test_jar_path_with_special_chars(self):
shutil.copy2(mock_jar, dest)
stdout, stderr = self._run_connect_in_subprocess(dest)
self.assertEqual(stdout, 'OK', f'Connection failed: {stderr}')


class ParallelConnectTest(unittest.TestCase):
"""Test that parallel connect() calls are thread-safe (issue #60)."""

def test_parallel_connects_after_jvm_started(self):
"""Multiple threads connecting simultaneously should not crash."""
errors = []

def connect_thread(idx):
import jpype
try:
conn = jaydebeapiarrow.connect(
'org.jaydebeapi.mockdriver.MockDriver',
'jdbc:jaydebeapi://dummyurl%d' % idx)
# Verify the connection is usable
self.assertIsNotNone(conn)
conn.close()
except Exception as e:
errors.append(e)
finally:
if jpype.isThreadAttachedToJVM():
jpype.detachThreadFromJVM()

threads = []
for i in range(5):
t = threading.Thread(target=partial(connect_thread, i))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()

self.assertEqual(errors, [], f"Thread errors: {errors}")

def test_jvm_startup_lock_exists(self):
"""The _jvm_startup_lock should be a threading.Lock."""
self.assertTrue(hasattr(jaydebeapiarrow, '_jvm_startup_lock'))
self.assertIsInstance(jaydebeapiarrow._jvm_startup_lock, type(threading.Lock()))
self.assertTrue(hasattr(jaydebeapiarrow, '_jvm_starting'))
Loading