diff --git a/jaydebeapiarrow/__init__.py b/jaydebeapiarrow/__init__.py index 9b1f8053..e4d028f6 100644 --- a/jaydebeapiarrow/__init__.py +++ b/jaydebeapiarrow/__init__.py @@ -591,28 +591,38 @@ def _to_java(p): return jpype.JClass("java.sql.Time").valueOf(p.isoformat()) if isinstance(p, Decimal): return jpype.JClass("java.math.BigDecimal")(str(p)) - if isinstance(p, list): - raise NotSupportedError( - "ARRAY type parameter binding is not supported. " - "Use server-side SQL functions to construct arrays, " - "or cast to VARCHAR in your query." - ) return p + def _to_java_array(p): + """Convert a Python list to a java.sql.Array for setArray().""" + if not p: + str_arr = jpype.JArray(jpype.JString)(0) + else: + str_arr = jpype.JArray(jpype.JString)( + [str(x) if x is not None else None for x in p]) + conn = statement.getConnection() + return conn.createArrayOf("VARCHAR", str_arr) + + def _bind_param(statement, idx, p): + if isinstance(p, list): + statement.setArray(idx, _to_java_array(p)) + else: + statement.setObject(idx, _to_java(p)) + if is_batch: for row in parameters: for i, p in enumerate(row): if p is None: statement.setNull(i + 1, Types_NULL) else: - statement.setObject(i + 1, _to_java(p)) + _bind_param(statement, i + 1, p) statement.addBatch() else: for i, p in enumerate(parameters): if p is None: statement.setNull(i + 1, Types_NULL) else: - statement.setObject(i + 1, _to_java(p)) + _bind_param(statement, i + 1, p) def execute(self, operation, parameters=None): if self._connection._closed: diff --git a/mockdriver/src/main/java/org/jaydebeapi/mockdriver/MockConnection.java b/mockdriver/src/main/java/org/jaydebeapi/mockdriver/MockConnection.java index 2468344a..3e838fa6 100644 --- a/mockdriver/src/main/java/org/jaydebeapi/mockdriver/MockConnection.java +++ b/mockdriver/src/main/java/org/jaydebeapi/mockdriver/MockConnection.java @@ -371,12 +371,14 @@ public final void mockType(String sqlTypesName) throws SQLException { private List capturedSetObjectArgs; private List capturedSetNullArgs; + private List capturedSetArrayArgs; - /** Set up a PreparedStatement that captures all setObject() and setNull() calls. + /** Set up a PreparedStatement that captures all setObject(), setNull() and setArray() calls. * Rejects Arrow-stream binding to force the _set_stmt_parms_fallback path. */ public final void mockSetObjectCapture() throws SQLException { capturedSetObjectArgs = new ArrayList<>(); capturedSetNullArgs = new ArrayList<>(); + capturedSetArrayArgs = new ArrayList<>(); // Throw by default so Arrow primary path fails and fallback is triggered, // but allow setNull() through (needed for NULL parameter binding tests). PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class, @@ -388,6 +390,7 @@ public final void mockSetObjectCapture() throws SQLException { }); Mockito.doReturn(true).when(mockPreparedStatement).execute(); Mockito.doNothing().when(mockPreparedStatement).close(); + Mockito.doReturn(this).when(mockPreparedStatement).getConnection(); mockResultSet = Mockito.mock(ResultSet.class, "ResultSet(for setObject capture)"); Mockito.doReturn(mockResultSet).when(mockPreparedStatement).getResultSet(); Mockito.doReturn(false).when(mockResultSet).next(); @@ -402,6 +405,16 @@ public final void mockSetObjectCapture() throws SQLException { capturedSetNullArgs.add(new Object[]{invocation.getArgument(0), invocation.getArgument(1)}); return null; }).when(mockPreparedStatement).setNull(Mockito.anyInt(), Mockito.anyInt()); + Mockito.doAnswer(invocation -> { + capturedSetArrayArgs.add(new Object[]{invocation.getArgument(0), invocation.getArgument(1)}); + return null; + }).when(mockPreparedStatement).setArray(Mockito.anyInt(), Mockito.any()); + // Stub createArrayOf to return a mock java.sql.Array + Mockito.doAnswer(invocation -> { + java.sql.Array mockArray = Mockito.mock(java.sql.Array.class); + Mockito.doReturn(invocation.getArgument(0)).when(mockArray).getBaseTypeName(); + return mockArray; + }).when(this).createArrayOf(Mockito.anyString(), Mockito.any()); Mockito.doReturn(mockPreparedStatement).when(this).prepareStatement(Mockito.any()); } @@ -429,6 +442,10 @@ public final List getCapturedSetNullArgs() { return capturedSetNullArgs; } + public final List getCapturedSetArrayArgs() { + return capturedSetArrayArgs; + } + public final void mockTimestampResult(LocalDateTime localDateTime) throws SQLException { PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class); @@ -446,6 +463,13 @@ public final void mockTimestampResult(LocalDateTime localDateTime) throws SQLExc Mockito.when(this.prepareStatement(Mockito.any())).thenReturn(mockPreparedStatement); } + @Override + public java.sql.Array createArrayOf(String typeName, Object[] elements) throws SQLException { + java.sql.Array mockArray = Mockito.mock(java.sql.Array.class); + Mockito.doReturn(typeName).when(mockArray).getBaseTypeName(); + return mockArray; + } + public final ResultSet verifyResultSet() { return Mockito.verify(mockResultSet); } diff --git a/test/test_integration.py b/test/test_integration.py index c9242f0d..71d50504 100644 --- a/test/test_integration.py +++ b/test/test_integration.py @@ -999,6 +999,49 @@ def test_rollback_with_autocommit_disabled(self): self.conn.jconn.setAutoCommit(False) self.conn.rollback() + def test_array_parameter_binding(self): + """Python list parameters should be bound as VARCHAR ARRAY via setArray().""" + with self.conn.cursor() as cursor: + cursor.execute("CREATE TABLE test_array_binding (id INT, tags VARCHAR(100) ARRAY)") + try: + cursor.execute( + "INSERT INTO test_array_binding (id, tags) VALUES (?, ?)", + (1, ["foo", "bar", "baz"]) + ) + self.assertEqual(cursor.rowcount, 1) + # Verify the row was inserted by checking the id + cursor.execute("SELECT id FROM test_array_binding WHERE id = ?", (1,)) + result = cursor.fetchone() + self.assertEqual(result[0], 1) + # Verify array contents via SQL CARDINALITY function + cursor.execute( + "SELECT CARDINALITY(tags) FROM test_array_binding WHERE id = ?", + (1,) + ) + result = cursor.fetchone() + self.assertEqual(result[0], 3) + finally: + cursor.execute("DROP TABLE test_array_binding") + + def test_empty_array_parameter_binding(self): + """Empty Python list should bind as an empty VARCHAR ARRAY.""" + with self.conn.cursor() as cursor: + cursor.execute("CREATE TABLE test_array_binding (id INT, tags VARCHAR(100) ARRAY)") + try: + cursor.execute( + "INSERT INTO test_array_binding (id, tags) VALUES (?, ?)", + (2, []) + ) + self.assertEqual(cursor.rowcount, 1) + cursor.execute( + "SELECT CARDINALITY(tags) FROM test_array_binding WHERE id = ?", + (2,) + ) + result = cursor.fetchone() + self.assertEqual(result[0], 0) + finally: + cursor.execute("DROP TABLE test_array_binding") + class PostgresTest(IntegrationTestBase, unittest.TestCase): diff --git a/test/test_mock.py b/test/test_mock.py index c45843d4..50a35235 100644 --- a/test/test_mock.py +++ b/test/test_mock.py @@ -710,12 +710,40 @@ def test_to_java_str_passthrough(self): self.assertEqual(len(captured), 1) self.assertEqual(captured[0][1], "hello") - def test_to_java_list_raises_not_supported(self): - """list should raise NotSupportedError for ARRAY binding.""" + def test_to_java_list_binds_as_array(self): + """list should be bound via setArray() for ARRAY type support.""" self.conn.jconn.mockSetObjectCapture() with self.conn.cursor() as cursor: - with self.assertRaises(jaydebeapiarrow.NotSupportedError): - cursor.execute("dummy stmt", ([1, 2, 3],)) + cursor.execute("dummy stmt", (["foo", "bar"],)) + captured = self.conn.jconn.getCapturedSetArrayArgs() + self.assertEqual(len(captured), 1) + self.assertEqual(captured[0][0], 1) # parameter index (1-based) + import jpype + self.assertIsInstance(captured[0][1], jpype.JClass("java.sql.Array")) + + def test_to_java_list_empty_binds_as_array(self): + """Empty list should be bound via setArray().""" + self.conn.jconn.mockSetObjectCapture() + with self.conn.cursor() as cursor: + cursor.execute("dummy stmt", ([],)) + captured = self.conn.jconn.getCapturedSetArrayArgs() + self.assertEqual(len(captured), 1) + + def test_to_java_list_with_none_elements(self): + """List containing None should be bound via setArray().""" + self.conn.jconn.mockSetObjectCapture() + with self.conn.cursor() as cursor: + cursor.execute("dummy stmt", (["a", None, "b"],)) + captured = self.conn.jconn.getCapturedSetArrayArgs() + self.assertEqual(len(captured), 1) + + def test_to_java_list_with_int_elements(self): + """List of ints should be stringified and bound as VARCHAR array.""" + self.conn.jconn.mockSetObjectCapture() + with self.conn.cursor() as cursor: + cursor.execute("dummy stmt", ([1, 2, 3],)) + captured = self.conn.jconn.getCapturedSetArrayArgs() + self.assertEqual(len(captured), 1) # --- Binary data round-trip tests ---