diff --git a/README.md b/README.md index b70564a..4bd2abb 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ iron_sql keeps SQL close to Python call sites while giving you typed, async quer Provide the schema file and DSN import string, then call `generate_sql_package()` with: - `schema_path`: path to the schema SQL file (relative to `src_path`). - `package_full_name`: target module, e.g. `myapp.db`. -- `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.value`. +- `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.get_value()`. - `src_path`: optional base source path for scanning queries (defaults current directory). - `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`). - `tempdir_path`: optional path for temporary file generation (useful for Docker mounts). diff --git a/src/iron_sql/generator.py b/src/iron_sql/generator.py index a135c3c..ae46683 100644 --- a/src/iron_sql/generator.py +++ b/src/iron_sql/generator.py @@ -7,7 +7,6 @@ from collections.abc import Callable from collections.abc import Iterator from dataclasses import dataclass -from operator import attrgetter from pathlib import Path import inflection @@ -90,7 +89,8 @@ def generate_sql_package( # noqa: PLR0913, PLR0914 queries = list(find_all_queries(src_path, sql_fn_name)) queries = list({q.name: q for q in queries}.values()) - dsn = attrgetter(dsn_import_path)(importlib.import_module(dsn_import_package)) + dsn_package = importlib.import_module(dsn_import_package) + dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307 sqlc_res = run_sqlc( src_path / schema_path, diff --git a/tests/conftest.py b/tests/conftest.py index 2a3e000..503a65d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,7 +157,7 @@ def __init__( self.dsn = dsn self.test_name = test_name self.schema_path = schema_path - self._sqlc = sqlc + self.sqlc = sqlc self.pkg_name = f"testapp_{test_name}.testdb" self.src_path = root / "src" self.app_pkg = f"testapp_{test_name}" @@ -225,7 +225,7 @@ def generate_no_import(self) -> bool: dsn_import=f"{self.app_pkg}.config:DSN", src_path=self.src_path, tempdir_path=self.src_path, - sqlc_command=self._sqlc.sqlc_command(), + sqlc_command=self.sqlc.sqlc_command(), ) def generate(self) -> Any: diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py index 2405d46..2387d0c 100644 --- a/tests/test_code_generation.py +++ b/tests/test_code_generation.py @@ -1,5 +1,9 @@ +import sys +from pathlib import Path + import pytest +from iron_sql import generate_sql_package from tests.conftest import ProjectBuilder @@ -116,3 +120,38 @@ async def test_special_types_params(test_project: ProjectBuilder) -> None: "VALUES ($1, $2, $3, $4, $5, $6)", ) assert test_project.generate_no_import() is True + + +def test_dsn_import_with_function_call(test_project: ProjectBuilder) -> None: + (test_project.app_dir / "config.py").write_text( + f""" +class Config: + def __init__(self, dsn: str): + self._dsn = dsn + def get_dsn(self) -> str: + return self._dsn + +CONFIG = Config("{test_project.dsn}") +""", + encoding="utf-8", + ) + + test_project.add_query("q", "SELECT 1 as value") + + if str(test_project.src_path) not in sys.path: + sys.path.insert(0, str(test_project.src_path)) + + generate_sql_package( + schema_path=Path("schema.sql"), + package_full_name=test_project.pkg_name, + dsn_import=f"{test_project.app_pkg}.config:CONFIG.get_dsn()", + src_path=test_project.src_path, + tempdir_path=test_project.src_path, + sqlc_command=test_project.sqlc.sqlc_command(), + ) + + generated_path = ( + test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py" + ) + generated = generated_path.read_text() + assert "CONFIG.get_dsn()" in generated