Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ iron_sql keeps SQL close to Python call sites while giving you typed, async quer
Provide the schema file and DSN import string, then call `generate_sql_package()` with:
- `schema_path`: path to the schema SQL file (relative to `src_path`).
- `package_full_name`: target module, e.g. `myapp.db`.
- `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.value`.
- `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.get_value()`.
- `src_path`: optional base source path for scanning queries (defaults current directory).
- `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
- `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
Expand Down
4 changes: 2 additions & 2 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from dataclasses import dataclass
from operator import attrgetter
from pathlib import Path

import inflection
Expand Down Expand Up @@ -90,7 +89,8 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
queries = list(find_all_queries(src_path, sql_fn_name))
queries = list({q.name: q for q in queries}.values())

dsn = attrgetter(dsn_import_path)(importlib.import_module(dsn_import_package))
dsn_package = importlib.import_module(dsn_import_package)
dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307

sqlc_res = run_sqlc(
src_path / schema_path,
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(
self.dsn = dsn
self.test_name = test_name
self.schema_path = schema_path
self._sqlc = sqlc
self.sqlc = sqlc
self.pkg_name = f"testapp_{test_name}.testdb"
self.src_path = root / "src"
self.app_pkg = f"testapp_{test_name}"
Expand Down Expand Up @@ -225,7 +225,7 @@ def generate_no_import(self) -> bool:
dsn_import=f"{self.app_pkg}.config:DSN",
src_path=self.src_path,
tempdir_path=self.src_path,
sqlc_command=self._sqlc.sqlc_command(),
sqlc_command=self.sqlc.sqlc_command(),
)

def generate(self) -> Any:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_code_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import sys
from pathlib import Path

import pytest

from iron_sql import generate_sql_package
from tests.conftest import ProjectBuilder


Expand Down Expand Up @@ -116,3 +120,38 @@ async def test_special_types_params(test_project: ProjectBuilder) -> None:
"VALUES ($1, $2, $3, $4, $5, $6)",
)
assert test_project.generate_no_import() is True


def test_dsn_import_with_function_call(test_project: ProjectBuilder) -> None:
(test_project.app_dir / "config.py").write_text(
f"""
class Config:
def __init__(self, dsn: str):
self._dsn = dsn
def get_dsn(self) -> str:
return self._dsn

CONFIG = Config("{test_project.dsn}")
""",
encoding="utf-8",
)

test_project.add_query("q", "SELECT 1 as value")

if str(test_project.src_path) not in sys.path:
sys.path.insert(0, str(test_project.src_path))

generate_sql_package(
schema_path=Path("schema.sql"),
package_full_name=test_project.pkg_name,
dsn_import=f"{test_project.app_pkg}.config:CONFIG.get_dsn()",
src_path=test_project.src_path,
tempdir_path=test_project.src_path,
sqlc_command=test_project.sqlc.sqlc_command(),
)

generated_path = (
test_project.src_path / f"{test_project.pkg_name.replace('.', '/')}.py"
)
generated = generated_path.read_text()
assert "CONFIG.get_dsn()" in generated