Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions dataframely/_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

import ast
import inspect
import sys
import textwrap
from abc import ABCMeta
Expand Down Expand Up @@ -30,6 +32,58 @@
# --------------------------------------- UTILS -------------------------------------- #


def _extract_column_docstrings(cls: type) -> dict[str, str]:
"""Extract docstrings for class attributes from source code.

This function parses the source code of a class to find string literals
that immediately follow attribute assignments. These are treated as
documentation strings for those attributes.

Args:
cls: The class to extract docstrings from.

Returns:
A dictionary mapping attribute names to their docstrings.
"""
try:
source = inspect.getsource(cls)
# Dedent to handle indented class definitions
tree = ast.parse(textwrap.dedent(source))

# Find the class definition
class_def = None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_def = node
break

if not class_def:
return {}

# Extract docstrings that appear after assignments
docstrings = {}
for i in range(len(class_def.body) - 1):
current = class_def.body[i]
next_stmt = class_def.body[i + 1]

# Check if current is an assignment and next is a string constant
if (
isinstance(current, ast.Assign)
and isinstance(next_stmt, ast.Expr)
and isinstance(next_stmt.value, ast.Constant)
and isinstance(next_stmt.value.value, str)
):
# Get the target name(s)
for target in current.targets:
if isinstance(target, ast.Name):
docstrings[target.id] = next_stmt.value.value

return docstrings
except (OSError, TypeError, SyntaxError):
# Source not available or cannot be parsed
return {}


def _build_rules(
custom: dict[str, Rule], columns: dict[str, Column], *, with_cast: bool
) -> dict[str, Rule]:
Expand Down Expand Up @@ -104,6 +158,21 @@ def __new__(
namespace[_COLUMN_ATTR] = result.columns
cls = super().__new__(mcs, name, bases, namespace, *args, **kwargs)

# Extract and attach docstrings to columns
docstrings = _extract_column_docstrings(cls)
for col_name, col in result.columns.items():
# Use the original attribute name (not alias) to match docstrings
original_name = None
for attr, value in namespace.items():
if value is col:
original_name = attr
break

# If we found a docstring for this column and it doesn't already have one,
# attach it
if original_name and original_name in docstrings and col.doc is None:
col.doc = docstrings[original_name]

# Assign rules retroactively as we only encounter rule factories in the result
rules = {name: factory.make(cls) for name, factory in result.rules.items()}
setattr(cls, _RULE_ATTR, rules)
Expand Down
46 changes: 37 additions & 9 deletions dataframely/columns/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
check: Check | None = None,
alias: str | None = None,
metadata: dict[str, Any] | None = None,
doc: str | None = None,
):
"""
Args:
Expand All @@ -70,6 +71,9 @@ def __init__(
this option does _not_ allow to refer to the column with two different
names, the specified alias is the only valid name.
metadata: A dictionary of metadata to attach to the column.
doc: A documentation string for the column. This can be automatically
extracted from a docstring placed immediately after the column definition
in a schema class.
"""

if nullable and primary_key:
Expand All @@ -80,6 +84,7 @@ def __init__(
self.check = check
self.alias = alias
self.metadata = metadata
self.doc = doc
# The name may be overridden by the schema on column access.
self._name = ""

Expand Down Expand Up @@ -299,7 +304,7 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]:
if self.__class__.__name__ not in _TYPE_MAPPING:
raise ValueError("Cannot serialize non-native dataframely column types.")

return {
result = {
"column_type": self.__class__.__name__,
**{
param: (
Expand All @@ -312,6 +317,11 @@ def as_dict(self, expr: pl.Expr) -> dict[str, Any]:
},
}

# Always include doc from the base Column class even if not in subclass signature
result["doc"] = self.doc

return result

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
"""Read the column definition from a dictionary.
Expand All @@ -325,13 +335,22 @@ def from_dict(cls, data: dict[str, Any]) -> Self:
Attention:
This method is only intended for internal use.
"""
return cls(
**{
k: (cast(Any, _check_from_expr(v)) if k == "check" else v)
for k, v in data.items()
if k != "column_type"
}
)
# Extract doc separately since it may not be in the subclass signature
doc_value = data.get("doc")

# Create the column with parameters that match its __init__ signature
column_data = {
k: (cast(Any, _check_from_expr(v)) if k == "check" else v)
for k, v in data.items()
if k not in ("column_type", "doc")
}

column = cls(**column_data)

# Set doc attribute directly if it was in the serialized data
column.doc = doc_value

return column

# ----------------------------------- EQUALITY ----------------------------------- #

Expand All @@ -350,7 +369,8 @@ def matches(self, other: Column, expr: pl.Expr) -> bool:
return False

attributes = inspect.signature(self.__class__.__init__)
return all(
# Check all attributes in the signature
sig_match = all(
self._attributes_match(
getattr(self, attr), getattr(other, attr), attr, expr
)
Expand All @@ -361,6 +381,9 @@ def matches(self, other: Column, expr: pl.Expr) -> bool:
if attr not in ("self", "alias")
)

# Also check the doc attribute from the base Column class
return sig_match and self.doc == other.doc

def _attributes_match(
self, lhs: Any, rhs: Any, name: str, column_expr: pl.Expr
) -> bool:
Expand All @@ -384,6 +407,11 @@ def __repr__(self) -> str:
getattr(self, attribute) == param_details.default
)
]

# Also include doc from base Column class if it's not None
if self.doc is not None:
parts.append(f"doc={repr(self.doc)}")

return f"{self.__class__.__name__}({', '.join(parts)})"

def __str__(self) -> str:
Expand Down
Loading
Loading