Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions jaydebeapiarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,28 +591,38 @@ def _to_java(p):
return jpype.JClass("java.sql.Time").valueOf(p.isoformat())
if isinstance(p, Decimal):
return jpype.JClass("java.math.BigDecimal")(str(p))
if isinstance(p, list):
raise NotSupportedError(
"ARRAY type parameter binding is not supported. "
"Use server-side SQL functions to construct arrays, "
"or cast to VARCHAR in your query."
)
return p

def _to_java_array(p):
"""Convert a Python list to a java.sql.Array for setArray()."""
if not p:
str_arr = jpype.JArray(jpype.JString)(0)
else:
str_arr = jpype.JArray(jpype.JString)(
[str(x) if x is not None else None for x in p])
conn = statement.getConnection()
return conn.createArrayOf("VARCHAR", str_arr)

def _bind_param(statement, idx, p):
if isinstance(p, list):
statement.setArray(idx, _to_java_array(p))
else:
statement.setObject(idx, _to_java(p))

if is_batch:
for row in parameters:
for i, p in enumerate(row):
if p is None:
statement.setNull(i + 1, Types_NULL)
else:
statement.setObject(i + 1, _to_java(p))
_bind_param(statement, i + 1, p)
statement.addBatch()
else:
for i, p in enumerate(parameters):
if p is None:
statement.setNull(i + 1, Types_NULL)
else:
statement.setObject(i + 1, _to_java(p))
_bind_param(statement, i + 1, p)

def execute(self, operation, parameters=None):
if self._connection._closed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,14 @@ public final void mockType(String sqlTypesName) throws SQLException {

private List<Object[]> capturedSetObjectArgs;
private List<Object[]> capturedSetNullArgs;
private List<Object[]> capturedSetArrayArgs;

/** Set up a PreparedStatement that captures all setObject() and setNull() calls.
/** Set up a PreparedStatement that captures all setObject(), setNull() and setArray() calls.
* Rejects Arrow-stream binding to force the _set_stmt_parms_fallback path. */
public final void mockSetObjectCapture() throws SQLException {
capturedSetObjectArgs = new ArrayList<>();
capturedSetNullArgs = new ArrayList<>();
capturedSetArrayArgs = new ArrayList<>();
// Throw by default so Arrow primary path fails and fallback is triggered,
// but allow setNull() through (needed for NULL parameter binding tests).
PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class,
Expand All @@ -388,6 +390,7 @@ public final void mockSetObjectCapture() throws SQLException {
});
Mockito.doReturn(true).when(mockPreparedStatement).execute();
Mockito.doNothing().when(mockPreparedStatement).close();
Mockito.doReturn(this).when(mockPreparedStatement).getConnection();
mockResultSet = Mockito.mock(ResultSet.class, "ResultSet(for setObject capture)");
Mockito.doReturn(mockResultSet).when(mockPreparedStatement).getResultSet();
Mockito.doReturn(false).when(mockResultSet).next();
Expand All @@ -402,6 +405,16 @@ public final void mockSetObjectCapture() throws SQLException {
capturedSetNullArgs.add(new Object[]{invocation.getArgument(0), invocation.getArgument(1)});
return null;
}).when(mockPreparedStatement).setNull(Mockito.anyInt(), Mockito.anyInt());
Mockito.doAnswer(invocation -> {
capturedSetArrayArgs.add(new Object[]{invocation.getArgument(0), invocation.getArgument(1)});
return null;
}).when(mockPreparedStatement).setArray(Mockito.anyInt(), Mockito.any());
// Stub createArrayOf to return a mock java.sql.Array
Mockito.doAnswer(invocation -> {
java.sql.Array mockArray = Mockito.mock(java.sql.Array.class);
Mockito.doReturn(invocation.getArgument(0)).when(mockArray).getBaseTypeName();
return mockArray;
}).when(this).createArrayOf(Mockito.anyString(), Mockito.any());
Mockito.doReturn(mockPreparedStatement).when(this).prepareStatement(Mockito.any());
}

Expand Down Expand Up @@ -429,6 +442,10 @@ public final List<Object[]> getCapturedSetNullArgs() {
return capturedSetNullArgs;
}

public final List<Object[]> getCapturedSetArrayArgs() {
return capturedSetArrayArgs;
}


public final void mockTimestampResult(LocalDateTime localDateTime) throws SQLException {
PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class);
Expand All @@ -446,6 +463,13 @@ public final void mockTimestampResult(LocalDateTime localDateTime) throws SQLExc
Mockito.when(this.prepareStatement(Mockito.any())).thenReturn(mockPreparedStatement);
}

@Override
public java.sql.Array createArrayOf(String typeName, Object[] elements) throws SQLException {
java.sql.Array mockArray = Mockito.mock(java.sql.Array.class);
Mockito.doReturn(typeName).when(mockArray).getBaseTypeName();
return mockArray;
}

public final ResultSet verifyResultSet() {
return Mockito.verify(mockResultSet);
}
Expand Down
43 changes: 43 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,49 @@ def test_rollback_with_autocommit_disabled(self):
self.conn.jconn.setAutoCommit(False)
self.conn.rollback()

def test_array_parameter_binding(self):
"""Python list parameters should be bound as VARCHAR ARRAY via setArray()."""
with self.conn.cursor() as cursor:
cursor.execute("CREATE TABLE test_array_binding (id INT, tags VARCHAR(100) ARRAY)")
try:
cursor.execute(
"INSERT INTO test_array_binding (id, tags) VALUES (?, ?)",
(1, ["foo", "bar", "baz"])
)
self.assertEqual(cursor.rowcount, 1)
# Verify the row was inserted by checking the id
cursor.execute("SELECT id FROM test_array_binding WHERE id = ?", (1,))
result = cursor.fetchone()
self.assertEqual(result[0], 1)
# Verify array contents via SQL CARDINALITY function
cursor.execute(
"SELECT CARDINALITY(tags) FROM test_array_binding WHERE id = ?",
(1,)
)
result = cursor.fetchone()
self.assertEqual(result[0], 3)
finally:
cursor.execute("DROP TABLE test_array_binding")

def test_empty_array_parameter_binding(self):
"""Empty Python list should bind as an empty VARCHAR ARRAY."""
with self.conn.cursor() as cursor:
cursor.execute("CREATE TABLE test_array_binding (id INT, tags VARCHAR(100) ARRAY)")
try:
cursor.execute(
"INSERT INTO test_array_binding (id, tags) VALUES (?, ?)",
(2, [])
)
self.assertEqual(cursor.rowcount, 1)
cursor.execute(
"SELECT CARDINALITY(tags) FROM test_array_binding WHERE id = ?",
(2,)
)
result = cursor.fetchone()
self.assertEqual(result[0], 0)
finally:
cursor.execute("DROP TABLE test_array_binding")


class PostgresTest(IntegrationTestBase, unittest.TestCase):

Expand Down
36 changes: 32 additions & 4 deletions test/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,40 @@ def test_to_java_str_passthrough(self):
self.assertEqual(len(captured), 1)
self.assertEqual(captured[0][1], "hello")

def test_to_java_list_raises_not_supported(self):
"""list should raise NotSupportedError for ARRAY binding."""
def test_to_java_list_binds_as_array(self):
"""list should be bound via setArray() for ARRAY type support."""
self.conn.jconn.mockSetObjectCapture()
with self.conn.cursor() as cursor:
with self.assertRaises(jaydebeapiarrow.NotSupportedError):
cursor.execute("dummy stmt", ([1, 2, 3],))
cursor.execute("dummy stmt", (["foo", "bar"],))
captured = self.conn.jconn.getCapturedSetArrayArgs()
self.assertEqual(len(captured), 1)
self.assertEqual(captured[0][0], 1) # parameter index (1-based)
import jpype
self.assertIsInstance(captured[0][1], jpype.JClass("java.sql.Array"))

def test_to_java_list_empty_binds_as_array(self):
"""Empty list should be bound via setArray()."""
self.conn.jconn.mockSetObjectCapture()
with self.conn.cursor() as cursor:
cursor.execute("dummy stmt", ([],))
captured = self.conn.jconn.getCapturedSetArrayArgs()
self.assertEqual(len(captured), 1)

def test_to_java_list_with_none_elements(self):
"""List containing None should be bound via setArray()."""
self.conn.jconn.mockSetObjectCapture()
with self.conn.cursor() as cursor:
cursor.execute("dummy stmt", (["a", None, "b"],))
captured = self.conn.jconn.getCapturedSetArrayArgs()
self.assertEqual(len(captured), 1)

def test_to_java_list_with_int_elements(self):
"""List of ints should be stringified and bound as VARCHAR array."""
self.conn.jconn.mockSetObjectCapture()
with self.conn.cursor() as cursor:
cursor.execute("dummy stmt", ([1, 2, 3],))
captured = self.conn.jconn.getCapturedSetArrayArgs()
self.assertEqual(len(captured), 1)

# --- Binary data round-trip tests ---

Expand Down
Loading