Skip to content

Commit 574c0ab

Browse files
committed
Convert tests to pyarrow; fix render_altair() LazyFrame check
- Replace all pl.DataFrame/pl.LazyFrame usage in tests with pa.table() - Make polars optional in tests (skip polars-specific tests when not installed) - Custom readers now use duckdb .arrow() instead of .pl() - Fix render_altair() where the LazyFrame collect + DataFrame isinstance checks were accidentally collapsed into a single inverted check
1 parent 5c0f8ed commit 574c0ab

1 file changed

Lines changed: 76 additions & 51 deletions

File tree

tests/test_ggsql.py

Lines changed: 76 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
import duckdb
1414
import pyarrow as pa
1515
import pytest
16-
import polars as pl
1716
import altair
1817

1918
import ggsql
2019

21-
# Optional dependency for ibis test
20+
# Optional dependencies
21+
try:
22+
import polars as pl
23+
24+
HAS_POLARS = True
25+
except ImportError:
26+
HAS_POLARS = False
27+
2228
try:
2329
import ibis
2430

@@ -86,6 +92,17 @@ def test_register_and_query(self):
8692
assert isinstance(result, pa.Table)
8793
assert result.shape == (2, 2)
8894

95+
@pytest.mark.skipif(not HAS_POLARS, reason="polars not installed")
96+
def test_register_polars_dataframe(self):
97+
"""register() accepts polars DataFrames via automatic conversion."""
98+
reader = ggsql.DuckDBReader("duckdb://memory")
99+
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
100+
reader.register("my_data", df)
101+
102+
result = reader.execute_sql("SELECT * FROM my_data WHERE x > 1")
103+
assert isinstance(result, pa.Table)
104+
assert result.shape == (2, 2)
105+
89106
def test_invalid_connection_string(self):
90107
with pytest.raises(ValueError):
91108
ggsql.DuckDBReader("invalid://connection")
@@ -110,8 +127,8 @@ def test_execute_simple_query(self):
110127

111128
def test_execute_with_registered_data(self):
112129
reader = ggsql.DuckDBReader("duckdb://memory")
113-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
114-
reader.register("data", df)
130+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
131+
reader.register("data", table)
115132

116133
spec = reader.execute("SELECT * FROM data VISUALISE x, y DRAW point")
117134
assert spec.metadata()["rows"] == 3
@@ -169,8 +186,8 @@ def test_render_to_vegalite(self):
169186

170187
def test_render_contains_data(self):
171188
reader = ggsql.DuckDBReader("duckdb://memory")
172-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
173-
reader.register("data", df)
189+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
190+
reader.register("data", table)
174191

175192
spec = reader.execute("SELECT * FROM data VISUALISE x, y DRAW point")
176193
writer = ggsql.VegaLiteWriter()
@@ -198,11 +215,18 @@ def test_render_multi_layer(self):
198215
class TestRenderAltairDataFrameConversion:
199216
"""Tests for DataFrame handling in render_altair()."""
200217

218+
def test_accepts_pyarrow_table(self):
219+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
220+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
221+
assert isinstance(chart, altair.TopLevelMixin)
222+
223+
@pytest.mark.skipif(not HAS_POLARS, reason="polars not installed")
201224
def test_accepts_polars_dataframe(self):
202225
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
203226
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
204227
assert isinstance(chart, altair.TopLevelMixin)
205228

229+
@pytest.mark.skipif(not HAS_POLARS, reason="polars not installed")
206230
def test_accepts_polars_lazyframe(self):
207231
lf = pl.LazyFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
208232
chart = ggsql.render_altair(lf, "VISUALISE x, y DRAW point")
@@ -211,8 +235,8 @@ def test_accepts_polars_lazyframe(self):
211235
def test_accepts_narwhals_dataframe(self):
212236
import narwhals as nw
213237

214-
pl_df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
215-
nw_df = nw.from_native(pl_df)
238+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
239+
nw_df = nw.from_native(table)
216240

217241
chart = ggsql.render_altair(nw_df, "VISUALISE x, y DRAW point")
218242
assert isinstance(chart, altair.TopLevelMixin)
@@ -233,20 +257,20 @@ class TestRenderAltairReturnType:
233257
"""Tests for render_altair() return type."""
234258

235259
def test_returns_altair_chart(self):
236-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
237-
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
260+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
261+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
238262
assert isinstance(chart, altair.TopLevelMixin)
239263

240264
def test_chart_has_data(self):
241-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
242-
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
265+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
266+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
243267
spec = chart.to_dict()
244268
# Data should be embedded in datasets
245269
assert "datasets" in spec
246270

247271
def test_chart_can_be_serialized(self):
248-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
249-
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
272+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
273+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
250274
# Should not raise
251275
json_str = chart.to_json()
252276
assert len(json_str) > 0
@@ -257,15 +281,15 @@ class TestRenderAltairChartTypeDetection:
257281

258282
def test_simple_chart_returns_layer_chart(self):
259283
"""Simple DRAW specs produce LayerChart (ggsql always wraps in layer)."""
260-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
261-
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
284+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
285+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
262286
# ggsql wraps all charts in a layer
263287
assert isinstance(chart, altair.LayerChart)
264288

265289
def test_layered_chart_can_round_trip(self):
266290
"""LayerChart can be converted to dict and back."""
267-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
268-
chart = ggsql.render_altair(df, "VISUALISE x, y DRAW point")
291+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
292+
chart = ggsql.render_altair(table, "VISUALISE x, y DRAW point")
269293

270294
# Convert to dict and back
271295
spec = chart.to_dict()
@@ -277,7 +301,7 @@ def test_layered_chart_can_round_trip(self):
277301

278302
def test_faceted_chart_returns_facet_chart(self):
279303
"""FACET specs produce FacetChart."""
280-
df = pl.DataFrame(
304+
table = pa.table(
281305
{
282306
"x": [1, 2, 3, 4, 5, 6],
283307
"y": [10, 20, 30, 40, 50, 60],
@@ -286,21 +310,21 @@ def test_faceted_chart_returns_facet_chart(self):
286310
)
287311
# Need validate=False because ggsql produces v6 specs
288312
chart = ggsql.render_altair(
289-
df, "VISUALISE x, y FACET group DRAW point", validate=False
313+
table, "VISUALISE x, y FACET group DRAW point", validate=False
290314
)
291315
assert isinstance(chart, altair.FacetChart)
292316

293317
def test_faceted_chart_can_round_trip(self):
294318
"""FacetChart can be converted to dict and back."""
295-
df = pl.DataFrame(
319+
table = pa.table(
296320
{
297321
"x": [1, 2, 3, 4, 5, 6],
298322
"y": [10, 20, 30, 40, 50, 60],
299323
"group": ["A", "A", "A", "B", "B", "B"],
300324
}
301325
)
302326
chart = ggsql.render_altair(
303-
df, "VISUALISE x, y FACET group DRAW point", validate=False
327+
table, "VISUALISE x, y FACET group DRAW point", validate=False
304328
)
305329

306330
# Convert to dict (skip validation for ggsql specs)
@@ -313,14 +337,16 @@ def test_faceted_chart_can_round_trip(self):
313337

314338
def test_chart_with_color_encoding(self):
315339
"""Charts with color encoding still return correct type."""
316-
df = pl.DataFrame(
340+
table = pa.table(
317341
{
318342
"x": [1, 2, 3, 4],
319343
"y": [10, 20, 30, 40],
320344
"category": ["A", "B", "A", "B"],
321345
}
322346
)
323-
chart = ggsql.render_altair(df, "VISUALISE x, y, category AS color DRAW point")
347+
chart = ggsql.render_altair(
348+
table, "VISUALISE x, y, category AS color DRAW point"
349+
)
324350
# Should still be a LayerChart (ggsql wraps in layer)
325351
assert isinstance(chart, altair.LayerChart)
326352

@@ -329,9 +355,9 @@ class TestRenderAltairErrorHandling:
329355
"""Tests for error handling in render_altair()."""
330356

331357
def test_invalid_viz_raises(self):
332-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
358+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
333359
with pytest.raises(ValueError):
334-
ggsql.render_altair(df, "NOT VALID SYNTAX")
360+
ggsql.render_altair(table, "NOT VALID SYNTAX")
335361

336362

337363
class TestTwoStageAPIIntegration:
@@ -343,14 +369,14 @@ def test_end_to_end_workflow(self):
343369
reader = ggsql.DuckDBReader("duckdb://memory")
344370

345371
# Register data
346-
df = pl.DataFrame(
372+
table = pa.table(
347373
{
348374
"date": ["2024-01-01", "2024-01-02", "2024-01-03"],
349375
"value": [10, 20, 30],
350376
"region": ["North", "South", "North"],
351377
}
352378
)
353-
reader.register("sales", df)
379+
reader.register("sales", table)
354380

355381
# Execute visualization
356382
spec = reader.execute(
@@ -400,11 +426,11 @@ class RegisterReader:
400426
def __init__(self):
401427
self.conn = duckdb.connect()
402428

403-
def execute_sql(self, sql: str) -> pl.DataFrame:
404-
return self.conn.execute(sql).pl()
429+
def execute_sql(self, sql: str) -> pa.Table:
430+
return self.conn.execute(sql).arrow()
405431

406-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
407-
self.conn.register(name, df)
432+
def register(self, name: str, table: pa.Table, _replace: bool) -> None:
433+
self.conn.register(name, table)
408434

409435
reader = RegisterReader()
410436
spec = ggsql.execute("SELECT 1 AS x, 2 AS y VISUALISE x, y DRAW point", reader)
@@ -414,7 +440,7 @@ def test_custom_reader_error_handling(self):
414440
"""Custom reader errors are propagated."""
415441

416442
class ErrorReader:
417-
def execute_sql(self, sql: str) -> pl.DataFrame:
443+
def execute_sql(self, sql: str) -> pa.Table:
418444
raise ValueError("Custom reader error")
419445

420446
reader = ErrorReader()
@@ -426,7 +452,7 @@ def test_custom_reader_wrong_return_type(self):
426452

427453
class WrongTypeReader:
428454
def execute_sql(self, sql: str):
429-
return {"x": [1, 2, 3]} # dict, not DataFrame
455+
return {"x": [1, 2, 3]} # dict, not Table
430456

431457
reader = WrongTypeReader()
432458
with pytest.raises((ValueError, TypeError)):
@@ -451,11 +477,11 @@ def __init__(self):
451477
") AS t(x, y, category)"
452478
)
453479

454-
def execute_sql(self, sql: str) -> pl.DataFrame:
455-
return self.conn.execute(sql).pl()
480+
def execute_sql(self, sql: str) -> pa.Table:
481+
return self.conn.execute(sql).arrow()
456482

457-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
458-
self.conn.register(name, df)
483+
def register(self, name: str, table: pa.Table, _replace: bool) -> None:
484+
self.conn.register(name, table)
459485

460486
reader = DuckDBBackedReader()
461487
spec = ggsql.execute(
@@ -481,12 +507,12 @@ def __init__(self):
481507
)
482508
self.execute_calls = []
483509

484-
def execute_sql(self, sql: str) -> pl.DataFrame:
510+
def execute_sql(self, sql: str) -> pa.Table:
485511
self.execute_calls.append(sql)
486-
return self.conn.execute(sql).pl()
512+
return self.conn.execute(sql).arrow()
487513

488-
def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
489-
self.conn.register(name, df)
514+
def register(self, name: str, table: pa.Table, _replace: bool) -> None:
515+
self.conn.register(name, table)
490516

491517
reader = RecordingReader()
492518
ggsql.execute(
@@ -507,20 +533,20 @@ class IbisReader:
507533
def __init__(self):
508534
self.con = ibis.duckdb.connect()
509535

510-
def execute_sql(self, sql: str) -> pl.DataFrame:
511-
return self.con.con.execute(sql).pl()
536+
def execute_sql(self, sql: str) -> pa.Table:
537+
return self.con.con.execute(sql).arrow()
512538

513539
def register(
514-
self, name: str, df: pl.DataFrame, replace: bool = True
540+
self, name: str, table: pa.Table, replace: bool = True
515541
) -> None:
516-
self.con.create_table(name, df.to_arrow(), overwrite=replace)
542+
self.con.create_table(name, table, overwrite=replace)
517543

518544
def unregister(self, name: str) -> None:
519545
self.con.drop_table(name)
520546

521547
reader = IbisReader()
522-
df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]})
523-
reader.register("mydata", df)
548+
table = pa.table({"x": [1, 2, 3], "y": [10, 20, 30]})
549+
reader.register("mydata", table)
524550

525551
spec = ggsql.execute(
526552
"SELECT * FROM mydata VISUALISE x, y DRAW point",
@@ -555,18 +581,17 @@ def test_render_chart_layer(self):
555581
def test_render_chart_facet(self):
556582
"""render_chart() returns FacetChart for faceted specs."""
557583
reader = ggsql.DuckDBReader("duckdb://memory")
558-
df = pl.DataFrame(
584+
table = pa.table(
559585
{
560586
"x": [1, 2, 3, 4, 5, 6],
561587
"y": [10, 20, 30, 40, 50, 60],
562588
"group": ["A", "A", "A", "B", "B", "B"],
563589
}
564590
)
565-
reader.register("data", df)
591+
reader.register("data", table)
566592
spec = reader.execute(
567593
"SELECT * FROM data VISUALISE x, y FACET group DRAW point"
568594
)
569595
writer = ggsql.VegaLiteWriter()
570596
chart = writer.render_chart(spec, validate=False)
571597
assert isinstance(chart, altair.FacetChart)
572-

0 commit comments

Comments
 (0)