Skip to content

Commit f3dcad8

Browse files
ElleNajtclaude
andcommitted
Add rich pretty printing support for PyTorch tensors
- Add tensor_repr() function in print_org_df.py that uses rich.pretty.Pretty - Override torch.Tensor.__repr__ when print_org_df.enable() is called - Temporarily restore original __repr__ during formatting to avoid recursion - Output uses rich indentation guides for better readability 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 20e9b07 commit f3dcad8

3 files changed

Lines changed: 150 additions & 0 deletions

File tree

python/print_org_df.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,125 @@ def disable_spark():
8484
except ImportError:
8585
DATAFRAME_IMAGE_AVAILABLE = False
8686

87+
TORCH_AVAILABLE = False
88+
_torch_original_str = None
89+
_torch_original_repr = None
90+
try:
91+
import torch
92+
93+
TORCH_AVAILABLE = True
94+
# Save original methods immediately at import time to avoid recursion
95+
_torch_original_str = torch.Tensor.__str__
96+
_torch_original_repr = torch.Tensor.__repr__
97+
except ImportError:
98+
TORCH_AVAILABLE = False
99+
100+
RICH_AVAILABLE = False
101+
try:
102+
from rich.console import Console
103+
from rich.panel import Panel
104+
from rich.table import Table
105+
from rich.text import Text
106+
107+
RICH_AVAILABLE = True
108+
except ImportError:
109+
RICH_AVAILABLE = False
110+
111+
112+
def tensor_repr(self):
113+
"""Simple rich repr for PyTorch tensors."""
114+
if not RICH_AVAILABLE:
115+
return _torch_original_repr(self)
116+
117+
from io import StringIO
118+
from rich.console import Console
119+
from rich.pretty import Pretty
120+
121+
# Temporarily restore original repr to avoid recursion
122+
torch.Tensor.__repr__ = _torch_original_repr
123+
try:
124+
# Create a console that outputs to string without ANSI codes
125+
output = StringIO()
126+
console = Console(file=output, force_terminal=False, width=120)
127+
console.print(Pretty(self, indent_guides=True))
128+
result = output.getvalue()
129+
finally:
130+
torch.Tensor.__repr__ = tensor_repr
131+
132+
return result
133+
134+
135+
def rich_tensor_repr(tensor):
136+
"""Format PyTorch tensor with rich library for better readability."""
137+
if not RICH_AVAILABLE:
138+
# Fallback to default repr if rich not available
139+
if _torch_original_repr is not None:
140+
return _torch_original_repr(tensor)
141+
return object.__repr__(tensor)
142+
143+
console = Console(width=120, force_terminal=False, force_jupyter=False)
144+
145+
# Use the original string func saved at module load time to avoid recursion
146+
original_str_func = (
147+
_torch_original_str if _torch_original_str is not None else object.__str__
148+
)
149+
150+
# Capture output to string
151+
with console.capture() as capture:
152+
# Create info table
153+
info_table = Table(show_header=False, box=None, padding=(0, 1))
154+
info_table.add_column("Property", style="cyan")
155+
info_table.add_column("Value", style="yellow")
156+
157+
# Basic info
158+
info_table.add_row("Shape", str(tuple(tensor.shape)))
159+
info_table.add_row("Dtype", str(tensor.dtype))
160+
info_table.add_row("Device", str(tensor.device))
161+
162+
# Gradient info
163+
if tensor.requires_grad:
164+
info_table.add_row("Requires Grad", "True")
165+
if tensor.grad_fn is not None:
166+
info_table.add_row("Grad Fn", str(tensor.grad_fn))
167+
168+
# Statistics for numeric tensors
169+
if tensor.numel() > 0 and tensor.dtype in [
170+
torch.float32,
171+
torch.float64,
172+
torch.float16,
173+
torch.bfloat16,
174+
]:
175+
try:
176+
info_table.add_row("Min", f"{tensor.min().item():.4f}")
177+
info_table.add_row("Max", f"{tensor.max().item():.4f}")
178+
info_table.add_row("Mean", f"{tensor.mean().item():.4f}")
179+
info_table.add_row("Std", f"{tensor.std().item():.4f}")
180+
except Exception:
181+
# Skip stats if they fail (NaNs, inf, empty tensors, etc.)
182+
# Better to show tensor without stats than crash the repr
183+
pass
184+
185+
console.print(info_table)
186+
187+
# Show tensor values (truncated if large)
188+
console.print("\n[bold]Values:[/bold]")
189+
190+
# Use torch's ORIGINAL repr for the values, but limit size
191+
if tensor.numel() <= 1000:
192+
console.print(original_str_func(tensor))
193+
else:
194+
# For large tensors, show a sample
195+
console.print(f"[dim](showing slice of {tensor.numel()} elements)[/dim]")
196+
if tensor.ndim == 1:
197+
console.print(original_str_func(tensor[:10]))
198+
elif tensor.ndim == 2:
199+
console.print(original_str_func(tensor[:5, :5]))
200+
else:
201+
# For higher dims, show first slice
202+
console.print(original_str_func(tensor[0]))
203+
204+
return capture.get()
205+
87206

88207
def image_repr(self, org_babel_filename, dpi=400):
89208
if not DATAFRAME_IMAGE_AVAILABLE:
@@ -172,6 +291,11 @@ def enable(repr_type, org_babel_filename=None, dpi=400):
172291

173292
if PYSPARK_AVAILABLE:
174293
SparkDataFrame.show = custom_spark_show
294+
295+
# Enable tensor pretty printing
296+
if TORCH_AVAILABLE and RICH_AVAILABLE:
297+
torch.Tensor.__repr__ = tensor_repr
298+
torch.Tensor.__str__ = tensor_repr
175299

176300
if repr_type == "image":
177301
for obj in [pd.DataFrame, pd.Series, pd.io.formats.style.Styler]:
@@ -202,6 +326,11 @@ def disable():
202326
obj.__repr__ = _original_repr[obj]
203327
obj.__str__ = _original_str[obj]
204328

329+
# Restore original tensor repr
330+
if TORCH_AVAILABLE and _torch_original_repr is not None:
331+
torch.Tensor.__repr__ = _torch_original_repr
332+
torch.Tensor.__str__ = _torch_original_str
333+
205334

206335
def is_enabled():
207336
return PANDAS_AVAILABLE and pd.DataFrame.__repr__ == org_repr

tests/babel-formatting.org

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,3 +466,23 @@ print(1)
466466
1
467467
:end:
468468

469+
470+
* Torch
471+
:PROPERTIES:
472+
:header-args: :results output drawer :python "nix-shell --run python" :tangle :session torch :timer-show no
473+
:END:
474+
475+
476+
#+begin_src python
477+
import torch
478+
x = torch.randn(3, 3)
479+
print(x)
480+
#+end_src
481+
482+
#+RESULTS:
483+
:results:
484+
tensor([[-1.0608, 0.0756, 0.4520],
485+
│ │ [-0.4603, -0.6774, -0.1871],
486+
│ │ [-0.6745, 0.0472, -0.8954]])
487+
:end:
488+

tests/shell.nix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ let
1919
pandas
2020
numpy
2121
scikit-learn
22+
torch
2223
matplotlib
2324
seaborn
2425
polars

0 commit comments

Comments
 (0)