@@ -614,7 +614,7 @@ def __init__(self, kernel_group, reason=None):
614614 self .target_cse_override = contextvars .ContextVar (f"Handler_cse_override_{ instance_id } " , default = self .cse )
615615 self ._nested_context_depth = 0
616616
617- def set_ranges (self , lengths , reduction_lengths ):
617+ def set_ranges (self , lengths , reduction_lengths , index_names = None ):
618618 if self .call_ranges :
619619 assert self .call_ranges == tuple (lengths ) + tuple (
620620 reduction_lengths
@@ -623,7 +623,12 @@ def set_ranges(self, lengths, reduction_lengths):
623623 else :
624624 self .call_ranges = tuple (lengths ) + tuple (reduction_lengths )
625625 self .ranges = [self .rename_indexing (x ) for x in self .call_ranges ]
626- self .itervars = [sympy .Symbol (f"index{ n } " ) for n in range (len (self .ranges ))]
626+ if index_names is None :
627+ self .itervars = [sympy .Symbol (f"index{ n } " ) for n in range (len (self .ranges ))]
628+ else :
629+ assert len (index_names ) == len (self .ranges ), f"Index names length mismatch: { len (index_names )} != { len (self .ranges )} "
630+ self .itervars = [sympy .Symbol (str (n )) for n in index_names ]
631+
627632 self .itervar_cses = {str (index ) : self .register_var_cse (str (index ), 1 , "index" ) for index in self .itervars }
628633 self .reduction_depth = len (lengths )
629634 return (
@@ -867,18 +872,22 @@ def rename_indexing(self, index) -> sympy.Expr:
867872 def override_buffer_cse (self , * , buffer = None , cse = None ):
868873 buffer_override = self .target_buffer_override
869874 cse_override = self .target_cse_override
870- target_buffer = target_cse = None
875+ buffer_token = cse_token = None
871876 try :
877+ # Store tokens for proper restoration in nested contexts
878+ # contextvars.set() returns the previous value (token) which can be used for reset()
872879 if buffer is not None :
873- target_buffer = buffer_override .set (buffer )
880+ buffer_token = buffer_override .set (buffer )
874881 if cse is not None :
875- target_cse = cse_override .set (cse )
882+ cse_token = cse_override .set (cse )
876883 yield self
877884 finally :
878- if target_cse is not None :
879- cse_override .reset (target_cse )
880- if target_buffer is not None :
881- buffer_override .reset (target_buffer )
885+ # Restore using tokens - contextvars automatically handles nested contexts
886+ # Each level restores to its own previous value
887+ if cse_token is not None :
888+ cse_override .reset (cse_token )
889+ if buffer_token is not None :
890+ buffer_override .reset (buffer_token )
882891
883892 def __enter__ (self ):
884893 class CSEProxy :
0 commit comments