Skip to content

Commit d8071ea

Browse files
avoid finding arg_id_to_descr in single_version_only
1 parent b8c0e25 commit d8071ea

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -323,18 +323,31 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
323323
:attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
324324
The intermediary pytato DAG for *args* is memoized in *self*.
325325
"""
326-
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
327-
args, kwargs)
328-
329-
try:
330-
compiled_f = self.program_cache[arg_id_to_descr]
331-
except KeyError as e:
332-
if self.single_version_only and self.program_cache:
333-
raise ValueError(
334-
f"Function '{self.f.__name__}' to be compiled "
335-
"was already compiled previously with different arguments.") from e
326+
if not self.single_version_only:
327+
arg_id_to_arg, arg_id_to_descr = \
328+
_get_arg_id_to_arg_and_arg_id_to_descr(args, kwargs)
329+
330+
try:
331+
compiled_f = self.program_cache[arg_id_to_descr]
332+
except KeyError:
333+
pass
334+
else:
335+
return compiled_f(arg_id_to_arg)
336336
else:
337-
return compiled_f(arg_id_to_arg)
337+
assert len(self.program_cache) <= 1
338+
339+
try:
340+
arg_id_to_descr, compiled_f = self.program_cache.popitem()
341+
except KeyError:
342+
pass
343+
else:
344+
if __debug__:
345+
current_arg_id_to_arg, current_arg_id_to_descr = \
346+
_get_arg_id_to_arg_and_arg_id_to_descr(args, kwargs)
347+
assert arg_id_to_descr == current_arg_id_to_descr
348+
assert self.arg_id_to_arg == current_arg_id_to_arg
349+
350+
return compiled_f(self.arg_id_to_arg)
338351

339352
dict_of_named_arrays = {}
340353
output_id_to_name_in_program = {}
@@ -377,6 +390,9 @@ def _as_dict_of_named_arrays(keys, ary):
377390
output_id_to_name_in_program=output_id_to_name_in_program,
378391
output_template=output_template)
379392

393+
if self.single_version_only:
394+
self.arg_id_to_arg = arg_id_to_arg
395+
380396
self.program_cache[arg_id_to_descr] = compiled_func
381397
return compiled_func(arg_id_to_arg)
382398

0 commit comments

Comments
 (0)