@@ -323,18 +323,31 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
323323 :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
324324 The intermediary pytato DAG for *args* is memoized in *self*.
325325 """
326- arg_id_to_arg , arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr (
327- args , kwargs )
328-
329- try :
330- compiled_f = self . program_cache [ arg_id_to_descr ]
331- except KeyError as e :
332- if self . single_version_only and self . program_cache :
333- raise ValueError (
334- f"Function ' { self . f . __name__ } ' to be compiled "
335- "was already compiled previously with different arguments." ) from e
326+ if not self . single_version_only :
327+ arg_id_to_arg , arg_id_to_descr = \
328+ _get_arg_id_to_arg_and_arg_id_to_descr ( args , kwargs )
329+
330+ try :
331+ compiled_f = self . program_cache [ arg_id_to_descr ]
332+ except KeyError :
333+ pass
334+ else :
335+ return compiled_f ( arg_id_to_arg )
336336 else :
337- return compiled_f (arg_id_to_arg )
337+ assert len (self .program_cache ) <= 1
338+
339+ try :
340+ arg_id_to_descr , compiled_f = self .program_cache .popitem ()
341+ except KeyError :
342+ pass
343+ else :
344+ if __debug__ :
345+ current_arg_id_to_arg , current_arg_id_to_descr = \
346+ _get_arg_id_to_arg_and_arg_id_to_descr (args , kwargs )
347+ assert arg_id_to_descr == current_arg_id_to_descr
348+ assert self .arg_id_to_arg == current_arg_id_to_arg
349+
350+ return compiled_f (self .arg_id_to_arg )
338351
339352 dict_of_named_arrays = {}
340353 output_id_to_name_in_program = {}
@@ -377,6 +390,9 @@ def _as_dict_of_named_arrays(keys, ary):
377390 output_id_to_name_in_program = output_id_to_name_in_program ,
378391 output_template = output_template )
379392
393+ if self .single_version_only :
394+ self .arg_id_to_arg = arg_id_to_arg
395+
380396 self .program_cache [arg_id_to_descr ] = compiled_func
381397 return compiled_func (arg_id_to_arg )
382398
0 commit comments