Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions src/dt_browser/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
SelectFromTable,
)
from dt_browser.bookmarks import Bookmarks
from dt_browser.column_metadata import ColumnMetadata
from dt_browser.column_selector import ColumnSelector
from dt_browser.custom_table import CustomTable, _color_name, polars_list_to_string
from dt_browser.expression_box import ExpressionBox
Expand Down Expand Up @@ -248,9 +249,8 @@ def compute_active_search_idx_display(self):
class RowDetail(Widget, can_focus=False, can_focus_children=False):
DEFAULT_CSS = """
RowDetail {
width: auto;
max_width: 50%;
min_width: 30%;
width: 100%;
height: 1fr;
padding: 0 1;
border: tall $primary;
}
Expand Down Expand Up @@ -288,14 +288,43 @@ def watch_row_df(self):
assert self._schema is not None
display_df = display_df.join(self._schema, on=["Field"]).select(["Field", "dtype", "Value"])
self._dt.set_dt(display_df, display_df.with_row_index(name=INDEX_COL).select([INDEX_COL]))
self.styles.width = self._dt.virtual_size.width + self.gutter.width + 1
self._dt.refresh()
if isinstance(self.parent, DetailPanel):
self.parent.update_width()
# self._dt.go_to_cell(coord)

@property
def content_width(self) -> int:
return self._dt.virtual_size.width + self.gutter.width + 1

def compose(self):
yield self._dt


class DetailPanel(Widget, can_focus=False):
DEFAULT_CSS = """
DetailPanel {
max-width: 50%;
min-width: 30%;
layout: vertical;
}
"""

def __init__(self, row_detail: RowDetail, column_metadata: ColumnMetadata, *args, **kwargs):
super().__init__(*args, **kwargs)
self._row_detail = row_detail
self._column_metadata = column_metadata

def update_width(self) -> None:
row_detail_width = self._row_detail.content_width if not self._row_detail.row_df.is_empty() else 0
meta_width = self._column_metadata.content_size.width + self._column_metadata.gutter.width
self.styles.width = max(row_detail_width, meta_width)

def compose(self):
yield self._row_detail
yield self._column_metadata


def from_file_path(path: pathlib.Path, has_header: bool = True) -> pl.DataFrame:

if path.suffix in [".arrow", ".feather"]:
Expand Down Expand Up @@ -342,6 +371,7 @@ class DtBrowser(Widget): # pylint: disable=too-many-public-methods,too-many-ins
current_filter = reactive[str | None](None)

cur_row = reactive(0)
cur_col = reactive(0)
cur_total_rows = reactive(0)
total_rows = reactive(0)

Expand Down Expand Up @@ -400,6 +430,9 @@ def __init__(
self._ts_col_selector.styles.width = 1

self._row_detail = RowDetail()
self._column_metadata = ColumnMetadata()
self._column_metadata.set_source_df(self._filtered_dt)
self._detail_panel = DetailPanel(self._row_detail, self._column_metadata)

self._color_by_cache: LRUCache[tuple[str, ...], pl.Series] = LRUCache(5)
self._last_message_ts = time.time()
Expand Down Expand Up @@ -623,10 +656,10 @@ async def action_show_save(self):

async def watch_show_row_detail(self):
if not self.show_row_detail:
if existing := self.query(RowDetail):
if existing := self.query(DetailPanel):
existing.remove()
elif not self._display_dt.is_empty():
await self.query_one("#main_hori", Horizontal).mount(self._row_detail)
await self.query_one("#main_hori", Horizontal).mount(self._detail_panel)

async def action_show_bookmarks(self):
await self.mount(self._bookmarks, before=self.query_one(TableFooter))
Expand All @@ -644,6 +677,8 @@ async def action_timestamp_selector(self):
def _set_filtered_dt(self, filtered_dt: pl.DataFrame, filtered_meta: pl.DataFrame, **kwargs):
self._filtered_dt = filtered_dt
self._meta_dt = filtered_meta
self._column_metadata.set_source_df(self._filtered_dt)
self._column_metadata.invalidate_cache()
self._set_active_dt(self._filtered_dt, **kwargs)

def _set_active_dt(self, active_dt: pl.DataFrame, new_row: int | None = None):
Expand Down Expand Up @@ -714,10 +749,20 @@ def enable_select_from_table(self, event: SelectFromTable):
@on(CustomTable.CellHighlighted, selector="#main_table")
async def handle_cell_highlight(self, event: CustomTable.CellHighlighted):
self.cur_row = event.coordinate.row
col = event.coordinate.column
if col != self.cur_col:
self.cur_col = col

def watch_cur_row(self):
self._row_detail.row_df = self._display_dt[self.cur_row]

def watch_cur_col(self):
if self._display_dt.is_empty() or self.cur_col >= len(self._display_dt.columns):
return
col_name = self._display_dt.columns[self.cur_col]
dtype = self._display_dt.schema[col_name]
self._column_metadata.column_info = (col_name, dtype)

@on(CustomTable.CellSelected, selector="#main_table")
def handle_cell_select(self, event: CustomTable.CellSelected):
if self._select_interest:
Expand Down Expand Up @@ -857,6 +902,9 @@ def on_mount(self):
self.cur_total_rows = len(self._display_dt)
self.total_rows = len(self._original_dt)
self._row_detail.row_df = self._display_dt[0]
if not self._display_dt.is_empty():
col_name = self._display_dt.columns[0]
self._column_metadata.column_info = (col_name, self._display_dt.schema[col_name])
if self.removed_cols:
err_str = ", ".join(f"{k}: {v}" for k, v in self.removed_cols.items())
self.notify(
Expand Down
134 changes: 134 additions & 0 deletions src/dt_browser/column_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import polars as pl
from rich.table import Table as RichTable
from textual import work
from textual.reactive import reactive
from textual.widget import Widget
from textual.widgets import Static


def _categorical_stats(series: pl.Series) -> list[tuple[str, str]]:
n_unique = series.n_unique()
stats: list[tuple[str, str]] = [("Unique values", str(n_unique))]
val_col = series.name
vc = series.value_counts().sort(["count", val_col], descending=[True, False]).head(10)
for row in vc.iter_rows(named=True):
stats.append((f" {row[val_col]}", str(row["count"])))
return stats


def _numeric_stats(series: pl.Series) -> list[tuple[str, str]]:
s = series.drop_nulls()
if s.is_empty():
return [("", "No data")]
stats = [
("Min", str(s.min())),
("Q1", str(s.quantile(0.25))),
("Median", str(s.median())),
("Q3", str(s.quantile(0.75))),
("Max", str(s.max())),
]
if s.dtype.is_float():
nan_count = s.is_nan().sum()
if nan_count > 0:
stats.append(("NaN", str(nan_count)))
return stats


def _temporal_stats(series: pl.Series) -> list[tuple[str, str]]:
s = series.drop_nulls()
if s.is_empty():
return [("", "No data")]
return [
("Min", str(s.min())),
("Max", str(s.max())),
]


def _boolean_stats(series: pl.Series) -> list[tuple[str, str]]:
true_count = series.sum()
null_count = series.null_count()
false_count = len(series) - (true_count or 0) - null_count
return [
("True", str(true_count)),
("False", str(false_count)),
]


def compute_column_stats(series: pl.Series) -> list[tuple[str, str]]:
dtype = series.dtype
if dtype == pl.Categorical:
stats = _categorical_stats(series)
elif dtype.is_numeric():
stats = _numeric_stats(series)
elif dtype.is_temporal():
stats = _temporal_stats(series)
elif dtype.is_(pl.Boolean):
stats = _boolean_stats(series)
else:
return []
null_count = series.null_count()
if null_count > 0:
stats.append(("Null", str(null_count)))
return stats


class ColumnMetadata(Widget, can_focus=False, can_focus_children=False):
DEFAULT_CSS = """
ColumnMetadata {
width: 100%;
height: auto;
padding: 0 1;
border: tall $primary;
}
"""
column_info: reactive[tuple[str, pl.DataType] | None] = reactive(None)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.border_title = "Column Metadata"
self._source_df: pl.DataFrame = pl.DataFrame()
self._cache: dict[str, list[tuple[str, str]]] = {}
self._static = Static("")

def set_source_df(self, df: pl.DataFrame) -> None:
self._source_df = df

def invalidate_cache(self) -> None:
self._cache.clear()

def _render_stats(self, col_name: str, stats: list[tuple[str, str]]) -> None:
self.border_title = f"Column: {col_name}"
if not stats:
self._static.update("")
return
table = RichTable(show_header=False, box=None, padding=(0, 1), expand=True)
table.add_column("Stat", no_wrap=True)
table.add_column("Value", no_wrap=True, justify="right")
for label, value in stats:
table.add_row(label, value)
self._static.update(table)
if self.parent is not None and hasattr(self.parent, "update_width"):
self.parent.update_width()

def watch_column_info(self) -> None:
if self.column_info is None or self._source_df.is_empty():
return
col_name, _ = self.column_info
if col_name not in self._source_df.columns:
return
if col_name in self._cache:
self._render_stats(col_name, self._cache[col_name])
else:
self.border_title = f"Column: {col_name}"
self._static.update("Computing...")
self._compute_stats(col_name)

@work(exclusive=True)
async def _compute_stats(self, col_name: str) -> None:
series = self._source_df[col_name]
stats = compute_column_stats(series)
self._cache[col_name] = stats
self._render_stats(col_name, stats)

def compose(self):
yield self._static
Loading
Loading