diff --git a/dataframely/_base_schema.py b/dataframely/_base_schema.py index 7ae64e3..8c9c1d2 100644 --- a/dataframely/_base_schema.py +++ b/dataframely/_base_schema.py @@ -3,6 +3,8 @@ from __future__ import annotations +import ast +import inspect import sys import textwrap from abc import ABCMeta @@ -30,6 +32,58 @@ # --------------------------------------- UTILS -------------------------------------- # +def _extract_column_docstrings(cls: type) -> dict[str, str]: + """Extract docstrings for class attributes from source code. + + This function parses the source code of a class to find string literals + that immediately follow attribute assignments. These are treated as + documentation strings for those attributes. + + Args: + cls: The class to extract docstrings from. + + Returns: + A dictionary mapping attribute names to their docstrings. + """ + try: + source = inspect.getsource(cls) + # Dedent to handle indented class definitions + tree = ast.parse(textwrap.dedent(source)) + + # Find the class definition + class_def = None + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + class_def = node + break + + if not class_def: + return {} + + # Extract docstrings that appear after assignments + docstrings = {} + for i in range(len(class_def.body) - 1): + current = class_def.body[i] + next_stmt = class_def.body[i + 1] + + # Check if current is an assignment and next is a string constant + if ( + isinstance(current, ast.Assign) + and isinstance(next_stmt, ast.Expr) + and isinstance(next_stmt.value, ast.Constant) + and isinstance(next_stmt.value.value, str) + ): + # Get the target name(s) + for target in current.targets: + if isinstance(target, ast.Name): + docstrings[target.id] = next_stmt.value.value + + return docstrings + except (OSError, TypeError, SyntaxError): + # Source not available or cannot be parsed + return {} + + def _build_rules( custom: dict[str, Rule], columns: dict[str, Column], *, with_cast: bool ) -> dict[str, Rule]: @@ -104,6 +158,21 
@@ def __new__( namespace[_COLUMN_ATTR] = result.columns cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs) + # Extract and attach docstrings to columns + docstrings = _extract_column_docstrings(cls) + for col_name, col in result.columns.items(): + # Use the original attribute name (not alias) to match docstrings + original_name = None + for attr, value in namespace.items(): + if value is col: + original_name = attr + break + + # If we found a docstring for this column and it doesn't already have one, + # attach it + if original_name and original_name in docstrings and col.doc is None: + col.doc = docstrings[original_name] + # Assign rules retroactively as we only encounter rule factories in the result rules = {name: factory.make(cls) for name, factory in result.rules.items()} setattr(cls, _RULE_ATTR, rules) diff --git a/dataframely/columns/_base.py b/dataframely/columns/_base.py index e01a1c8..8514181 100644 --- a/dataframely/columns/_base.py +++ b/dataframely/columns/_base.py @@ -47,6 +47,7 @@ def __init__( check: Check | None = None, alias: str | None = None, metadata: dict[str, Any] | None = None, + doc: str | None = None, ): """ Args: @@ -70,6 +71,9 @@ def __init__( this option does _not_ allow to refer to the column with two different names, the specified alias is the only valid name. metadata: A dictionary of metadata to attach to the column. + doc: A documentation string for the column. This can be automatically + extracted from a docstring placed immediately after the column definition + in a schema class. """ if nullable and primary_key: @@ -80,6 +84,7 @@ def __init__( self.check = check self.alias = alias self.metadata = metadata + self.doc = doc # The name may be overridden by the schema on column access. 
self._name = "" @@ -299,7 +304,7 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: if self.__class__.__name__ not in _TYPE_MAPPING: raise ValueError("Cannot serialize non-native dataframely column types.") - return { + result = { "column_type": self.__class__.__name__, **{ param: ( @@ -312,6 +317,11 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]: }, } + # Always include doc from the base Column class even if not in subclass signature + result["doc"] = self.doc + + return result + @classmethod def from_dict(cls, data: dict[str, Any]) -> Self: """Read the column definition from a dictionary. @@ -325,13 +335,22 @@ def from_dict(cls, data: dict[str, Any]) -> Self: Attention: This method is only intended for internal use. """ - return cls( - **{ - k: (cast(Any, _check_from_expr(v)) if k == "check" else v) - for k, v in data.items() - if k != "column_type" - } - ) + # Extract doc separately since it may not be in the subclass signature + doc_value = data.get("doc") + + # Create the column with parameters that match its __init__ signature + column_data = { + k: (cast(Any, _check_from_expr(v)) if k == "check" else v) + for k, v in data.items() + if k not in ("column_type", "doc") + } + + column = cls(**column_data) + + # Set doc attribute directly if it was in the serialized data + column.doc = doc_value + + return column # ----------------------------------- EQUALITY ----------------------------------- # @@ -350,7 +369,8 @@ def matches(self, other: Column, expr: pl.Expr) -> bool: return False attributes = inspect.signature(self.__class__.__init__) - return all( + # Check all attributes in the signature + sig_match = all( self._attributes_match( getattr(self, attr), getattr(other, attr), attr, expr ) @@ -361,6 +381,9 @@ def matches(self, other: Column, expr: pl.Expr) -> bool: if attr not in ("self", "alias") ) + # Also check the doc attribute from the base Column class + return sig_match and self.doc == other.doc + def _attributes_match( self, lhs: Any, rhs: 
Any, name: str, column_expr: pl.Expr ) -> bool: @@ -384,6 +407,11 @@ def __repr__(self) -> str: getattr(self, attribute) == param_details.default ) ] + + # Also include doc from base Column class if it's not None + if self.doc is not None: + parts.append(f"doc={repr(self.doc)}") + return f"{self.__class__.__name__}({', '.join(parts)})" def __str__(self) -> str: diff --git a/tests/columns/test_docstrings.py b/tests/columns/test_docstrings.py new file mode 100644 index 0000000..564ddf0 --- /dev/null +++ b/tests/columns/test_docstrings.py @@ -0,0 +1,219 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import polars as pl + +import dataframely as dy + + +def test_column_docstring_basic() -> None: + """Test basic column docstring extraction.""" + + class MySchema(dy.Schema): + """Schema docstring.""" + + col1 = dy.String(nullable=False) + """This is the documentation for col1.""" + + col2 = dy.Integer() + """This is the documentation for col2.""" + + col3 = dy.Float64() + + columns = MySchema.columns() + assert columns["col1"].doc == "This is the documentation for col1." + assert columns["col2"].doc == "This is the documentation for col2." + assert columns["col3"].doc is None + + +def test_column_docstring_multiline() -> None: + """Test multiline column docstrings.""" + + class MySchema(dy.Schema): + col1 = dy.String() + """This is a multiline docstring. + + It has multiple lines and paragraphs. + """ + + col2 = dy.Integer() + """Single line after col2.""" + + columns = MySchema.columns() + assert columns["col1"].doc is not None and "multiline" in columns["col1"].doc + assert columns["col1"].doc is not None and "multiple lines" in columns["col1"].doc + assert columns["col2"].doc == "Single line after col2." 
def test_column_docstring_with_alias() -> None:
    """An aliased column picks up the docstring written below its attribute."""

    class MySchema(dy.Schema):
        col_python_name = dy.String(alias="col-sql-name")
        """Documentation for the aliased column."""

    by_name = MySchema.columns()
    # Columns are keyed by their alias, not by the Python attribute name.
    assert "col-sql-name" in by_name
    aliased = by_name["col-sql-name"]
    assert aliased.doc == "Documentation for the aliased column."


def test_column_docstring_with_inheritance() -> None:
    """A child schema exposes both inherited and own column docstrings."""

    class BaseSchema(dy.Schema):
        base_col = dy.String()
        """Base column documentation."""

    class ChildSchema(BaseSchema):
        child_col = dy.Integer()
        """Child column documentation."""

    by_name = ChildSchema.columns()
    inherited = by_name["base_col"]
    assert inherited.doc == "Base column documentation."
    own = by_name["child_col"]
    assert own.doc == "Child column documentation."


def test_column_docstring_overridden_in_child() -> None:
    """Redefining a column in a child schema replaces its docstring there only."""

    class ParentSchema(dy.Schema):
        col1 = dy.String()
        """Parent documentation."""

    class ChildSchema(ParentSchema):
        col1 = dy.String()
        """Child documentation."""

    from_parent = ParentSchema.columns()
    from_child = ChildSchema.columns()

    # The parent keeps its own text; the child sees the overriding text.
    assert from_parent["col1"].doc == "Parent documentation."
    assert from_child["col1"].doc == "Child documentation."
def test_column_docstring_serialization() -> None:
    """Column docstrings survive a serialize/deserialize round trip."""

    class MySchema(dy.Schema):
        col1 = dy.String(nullable=False)
        """Documentation for col1."""

        col2 = dy.Integer()

    payload = MySchema.serialize()

    from dataframely.schema import deserialize_schema

    restored = deserialize_schema(payload)

    restored_columns = restored.columns()
    assert restored_columns["col1"].doc == "Documentation for col1."
    assert restored_columns["col2"].doc is None


def test_column_docstring_validation_not_affected() -> None:
    """Documented columns validate data exactly like undocumented ones."""

    class MySchema(dy.Schema):
        name = dy.String(nullable=False)
        """Name documentation."""

        age = dy.UInt8()
        """Age documentation."""

    # A small frame that satisfies the schema once cast.
    frame = pl.DataFrame({"name": ["Alice", "Bob"], "age": [30, 25]})

    validated = MySchema.validate(frame, cast=True)
    assert len(validated) == 2


def test_column_docstring_in_repr() -> None:
    """The column repr mentions ``doc`` only when a docstring is present."""

    class MySchema(dy.Schema):
        col_with_doc = dy.String()
        """Test documentation."""

        col_without_doc = dy.String()

    by_name = MySchema.columns()
    documented = by_name["col_with_doc"]
    undocumented = by_name["col_without_doc"]

    assert "doc=" in repr(documented)
    # A missing docstring is the default and must not show up in the repr.
    assert "doc=" not in repr(undocumented)


def test_column_docstring_empty_string() -> None:
    """Empty string literals still count as (empty) docstrings."""

    class MySchema(dy.Schema):
        col1 = dy.String()
        ""

        col2 = dy.Integer()
        """"""

    by_name = MySchema.columns()
    # Both spellings of the empty string are captured rather than ignored.
    assert by_name["col1"].doc == ""
    assert by_name["col2"].doc == ""


def test_schema_matches_with_docstrings() -> None:
    """Schema matching takes column docstrings into account."""

    class Schema1(dy.Schema):
        col1 = dy.String()
        """Doc 1."""

    class Schema2(dy.Schema):
        col1 = dy.String()
        """Doc 1."""

    class Schema3(dy.Schema):
        col1 = dy.String()
        """Doc 2."""

    class Schema4(dy.Schema):
        col1 = dy.String()

    # Identical docstrings: match.
    assert Schema1.matches(Schema2)

    # Differing docstrings: no match.
    assert not Schema1.matches(Schema3)

    # Docstring present on one side only: no match.
    assert not Schema1.matches(Schema4)


def test_column_docstring_with_primary_key() -> None:
    """Docstrings are attached to primary key columns as well."""

    class MySchema(dy.Schema):
        id = dy.Integer(primary_key=True)
        """The unique identifier."""

        name = dy.String()
        """The name field."""

    by_name = MySchema.columns()
    key_column = by_name["id"]
    assert key_column.doc == "The unique identifier."
    assert key_column.primary_key is True
    assert by_name["name"].doc == "The name field."