@@ -84,6 +84,125 @@ def disable_spark():
8484except ImportError :
8585 DATAFRAME_IMAGE_AVAILABLE = False
8686
87+ TORCH_AVAILABLE = False
88+ _torch_original_str = None
89+ _torch_original_repr = None
90+ try :
91+ import torch
92+
93+ TORCH_AVAILABLE = True
94+ # Save original methods immediately at import time to avoid recursion
95+ _torch_original_str = torch .Tensor .__str__
96+ _torch_original_repr = torch .Tensor .__repr__
97+ except ImportError :
98+ TORCH_AVAILABLE = False
99+
100+ RICH_AVAILABLE = False
101+ try :
102+ from rich .console import Console
103+ from rich .panel import Panel
104+ from rich .table import Table
105+ from rich .text import Text
106+
107+ RICH_AVAILABLE = True
108+ except ImportError :
109+ RICH_AVAILABLE = False
110+
111+
112+ def tensor_repr (self ):
113+ """Simple rich repr for PyTorch tensors."""
114+ if not RICH_AVAILABLE :
115+ return _torch_original_repr (self )
116+
117+ from io import StringIO
118+ from rich .console import Console
119+ from rich .pretty import Pretty
120+
121+ # Temporarily restore original repr to avoid recursion
122+ torch .Tensor .__repr__ = _torch_original_repr
123+ try :
124+ # Create a console that outputs to string without ANSI codes
125+ output = StringIO ()
126+ console = Console (file = output , force_terminal = False , width = 120 )
127+ console .print (Pretty (self , indent_guides = True ))
128+ result = output .getvalue ()
129+ finally :
130+ torch .Tensor .__repr__ = tensor_repr
131+
132+ return result
133+
134+
135+ def rich_tensor_repr (tensor ):
136+ """Format PyTorch tensor with rich library for better readability."""
137+ if not RICH_AVAILABLE :
138+ # Fallback to default repr if rich not available
139+ if _torch_original_repr is not None :
140+ return _torch_original_repr (tensor )
141+ return object .__repr__ (tensor )
142+
143+ console = Console (width = 120 , force_terminal = False , force_jupyter = False )
144+
145+ # Use the original string func saved at module load time to avoid recursion
146+ original_str_func = (
147+ _torch_original_str if _torch_original_str is not None else object .__str__
148+ )
149+
150+ # Capture output to string
151+ with console .capture () as capture :
152+ # Create info table
153+ info_table = Table (show_header = False , box = None , padding = (0 , 1 ))
154+ info_table .add_column ("Property" , style = "cyan" )
155+ info_table .add_column ("Value" , style = "yellow" )
156+
157+ # Basic info
158+ info_table .add_row ("Shape" , str (tuple (tensor .shape )))
159+ info_table .add_row ("Dtype" , str (tensor .dtype ))
160+ info_table .add_row ("Device" , str (tensor .device ))
161+
162+ # Gradient info
163+ if tensor .requires_grad :
164+ info_table .add_row ("Requires Grad" , "True" )
165+ if tensor .grad_fn is not None :
166+ info_table .add_row ("Grad Fn" , str (tensor .grad_fn ))
167+
168+ # Statistics for numeric tensors
169+ if tensor .numel () > 0 and tensor .dtype in [
170+ torch .float32 ,
171+ torch .float64 ,
172+ torch .float16 ,
173+ torch .bfloat16 ,
174+ ]:
175+ try :
176+ info_table .add_row ("Min" , f"{ tensor .min ().item ():.4f} " )
177+ info_table .add_row ("Max" , f"{ tensor .max ().item ():.4f} " )
178+ info_table .add_row ("Mean" , f"{ tensor .mean ().item ():.4f} " )
179+ info_table .add_row ("Std" , f"{ tensor .std ().item ():.4f} " )
180+ except Exception :
181+ # Skip stats if they fail (NaNs, inf, empty tensors, etc.)
182+ # Better to show tensor without stats than crash the repr
183+ pass
184+
185+ console .print (info_table )
186+
187+ # Show tensor values (truncated if large)
188+ console .print ("\n [bold]Values:[/bold]" )
189+
190+ # Use torch's ORIGINAL repr for the values, but limit size
191+ if tensor .numel () <= 1000 :
192+ console .print (original_str_func (tensor ))
193+ else :
194+ # For large tensors, show a sample
195+ console .print (f"[dim](showing slice of { tensor .numel ()} elements)[/dim]" )
196+ if tensor .ndim == 1 :
197+ console .print (original_str_func (tensor [:10 ]))
198+ elif tensor .ndim == 2 :
199+ console .print (original_str_func (tensor [:5 , :5 ]))
200+ else :
201+ # For higher dims, show first slice
202+ console .print (original_str_func (tensor [0 ]))
203+
204+ return capture .get ()
205+
87206
88207def image_repr (self , org_babel_filename , dpi = 400 ):
89208 if not DATAFRAME_IMAGE_AVAILABLE :
@@ -172,6 +291,11 @@ def enable(repr_type, org_babel_filename=None, dpi=400):
172291
173292 if PYSPARK_AVAILABLE :
174293 SparkDataFrame .show = custom_spark_show
294+
295+ # Enable tensor pretty printing
296+ if TORCH_AVAILABLE and RICH_AVAILABLE :
297+ torch .Tensor .__repr__ = tensor_repr
298+ torch .Tensor .__str__ = tensor_repr
175299
176300 if repr_type == "image" :
177301 for obj in [pd .DataFrame , pd .Series , pd .io .formats .style .Styler ]:
@@ -202,6 +326,11 @@ def disable():
202326 obj .__repr__ = _original_repr [obj ]
203327 obj .__str__ = _original_str [obj ]
204328
329+ # Restore original tensor repr
330+ if TORCH_AVAILABLE and _torch_original_repr is not None :
331+ torch .Tensor .__repr__ = _torch_original_repr
332+ torch .Tensor .__str__ = _torch_original_str
333+
205334
206335def is_enabled ():
207336 return PANDAS_AVAILABLE and pd .DataFrame .__repr__ == org_repr
0 commit comments