Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,19 +436,22 @@ def _bind_to_class(self, toml_dict: Mapping[str, Any], instance: T | None, conte
else:
return replace(instance, **parsed) # type: ignore[type-var]

def format_toml(self) -> Iterator[str]:
def format_toml(self, context: str = "") -> Iterator[str]:
"""
Yield lines of TOML text for populating the dataclass or object that we are binding to.

If we are binding to an object, non-default values from that object will be output.

If we are binding to a class, example values for mandatory fields will be derived from the field types;
these example values can be syntactically incorrect placeholders.

The `context` parameter contains a dot-separated key path for the bound object/class:
this will be prefixed to all yielded TOML table names.
"""

return self._format_toml_root(template=False)
return self._format_toml_root(context=context, template=False)

def format_toml_template(self) -> Iterator[str]:
def format_toml_template(self, context: str = "") -> Iterator[str]:
"""
Yield lines of TOML text as a template for populating the dataclass or object that we are binding to.

Expand All @@ -457,12 +460,15 @@ def format_toml_template(self) -> Iterator[str]:

If we are binding to an object, values from that object will be used to populate the template.
If we are binding to a class, example values will be derived from the field types.

The `context` parameter contains a dot-separated key path for the bound object/class:
this will be prefixed to all yielded TOML table names.
"""

return self._format_toml_root(template=True)
return self._format_toml_root(context=context, template=True)

def _format_toml_root(self, *, template: bool) -> Iterator[str]:
table = Table(self, "", self._instance, None)
def _format_toml_root(self, *, context: str, template: bool) -> Iterator[str]:
table = Table(self, context, self._instance, None)
lines = table.format_table(set(), template=template)
for line in lines:
if line:
Expand Down
25 changes: 25 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ class Inner:
behind_the_curtain: str = field(init=False, default="wizard")


@dataclass(kw_only=True)
class Outer:
top_level: int
nested: Inner


@pytest.mark.parametrize("optional", (True, False))
@pytest.mark.parametrize("string", (True, False))
def test_format_value_nested_dataclass(*, optional: bool, string: bool) -> None:
Expand All @@ -239,6 +245,25 @@ def test_format_value_nested_dataclass(*, optional: bool, string: bool) -> None:
assert round_trip_value(value, dc) == value


@pytest.mark.parametrize(
("context", "expected_headers"),
[
("", ["[nested]"]),
("one", ["[one]", "[one.nested]"]),
("one.two", ["[one.two]", "[one.two.nested]"]),
],
)
def test_format_value_context(context: str, expected_headers: list[str]) -> None:
obj = Outer(top_level=123, nested=Inner(key_containing_underscores=True, maybesuffix=timedelta(days=2)))

lines = list(Binder(obj).format_toml_template(context=context))
toml = "\n".join(lines)
print(toml) # noqa: T201

actual_headers = [line for line in lines if line.startswith("[")]
assert actual_headers == expected_headers


def test_format_value_unsupported_type() -> None:
with pytest.raises(TypeError, match=r"^NoneType$"):
format_toml_pair("unsupported", None)
Expand Down