@@ -529,8 +529,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
529529 return pytato_program , name_in_program_to_tags , name_in_program_to_axes
530530
531531
532- def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ,
533- fn_name = "<unknown>" ):
532+ def _args_to_device_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
534533 input_kwargs_for_loopy = {}
535534
536535 for arg_id , arg in arg_id_to_arg .items ():
@@ -551,20 +550,32 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg,
551550 # got a frozen array => do nothing
552551 pass
553552 elif isinstance (arg , pt .Array ):
554- # got an array expression => abort
555- raise ValueError (
556- f"Argument '{ arg_id } ' to the '{ fn_name } ' compiled function is a"
557- " pytato array expression. Evaluating it just-in-time"
558- " potentially causes a significant overhead on each call to the"
559- " function and is therefore unsupported. "
560- )
553+ # got an array expression => evaluate it
554+ from warnings import warn
555+ warn (f"Argument array '{ arg_id } ' to a compiled function is "
556+ "unevaluated. Evaluating just-in-time, at "
557+ "considerable expense. This is deprecated and will stop "
558+ "working in 2023. To avoid this warning, force evaluation "
559+ "of all arguments via freeze/thaw." ,
560+ DeprecationWarning , stacklevel = 4 )
561+
562+ arg = actx .freeze (arg )
561563 else :
562564 raise NotImplementedError (type (arg ))
563565
564566 input_kwargs_for_loopy [input_id_to_name_in_program [arg_id ]] = arg
565567
566568 return input_kwargs_for_loopy
567569
570+
571+ def _args_to_cl_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
572+ from warnings import warn
573+ warn ("_args_to_cl_buffer has been renamed to"
574+ " _args_to_device_buffers. This will be"
575+ " an error in 2023." , DeprecationWarning , stacklevel = 2 )
576+ return _args_to_device_buffers (actx , input_id_to_name_in_program ,
577+ arg_id_to_arg )
578+
568579# }}}
569580
570581
@@ -620,7 +631,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
620631 type of the callable.
621632 """
622633 actx : PytatoPyOpenCLArrayContext
623- pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
634+ pytato_program : pt .target .BoundProgram
624635 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
625636 output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
626637 name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
@@ -631,10 +642,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
631642 from .utils import get_cl_axes_from_pt_axes
632643 from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
633644
634- fn_name = self .pytato_program .program .entrypoint
635-
636645 input_kwargs_for_loopy = _args_to_device_buffers (
637- self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
646+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
638647
639648 evt , out_dict = self .pytato_program (queue = self .actx .queue ,
640649 allocator = self .actx .allocator ,
@@ -665,7 +674,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction):
665674 Name of the output array in the program.
666675 """
667676 actx : PytatoPyOpenCLArrayContext
668- pytato_program : pt .target .loopy . BoundPyOpenCLExecutable
677+ pytato_program : pt .target .BoundProgram
669678 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
670679 output_tags : frozenset [Tag ]
671680 output_axes : tuple [pt .Axis , ...]
@@ -675,10 +684,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
675684 from .utils import get_cl_axes_from_pt_axes
676685 from arraycontext .impl .pyopencl .taggable_cl_array import to_tagged_cl_array
677686
678- fn_name = self .pytato_program .program .entrypoint
679-
680687 input_kwargs_for_loopy = _args_to_device_buffers (
681- self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
688+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
682689
683690 evt , out_dict = self .pytato_program (queue = self .actx .queue ,
684691 allocator = self .actx .allocator ,
@@ -716,18 +723,16 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
716723 type of the callable.
717724 """
718725 actx : PytatoJAXArrayContext
719- pytato_program : pt .target .python . BoundJAXPythonProgram
726+ pytato_program : pt .target .BoundProgram
720727 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
721728 output_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
722729 name_in_program_to_tags : Mapping [str , frozenset [Tag ]]
723730 name_in_program_to_axes : Mapping [str , tuple [pt .Axis , ...]]
724731 output_template : ArrayContainer
725732
726733 def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
727- fn_name = self .pytato_program .entrypoint
728-
729734 input_kwargs_for_loopy = _args_to_device_buffers (
730- self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
735+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
731736
732737 out_dict = self .pytato_program (** input_kwargs_for_loopy )
733738
@@ -749,17 +754,15 @@ class CompiledJAXFunctionReturningArray(CompiledFunction):
749754 Name of the output array in the program.
750755 """
751756 actx : PytatoJAXArrayContext
752- pytato_program : pt .target .python . BoundJAXPythonProgram
757+ pytato_program : pt .target .BoundProgram
753758 input_id_to_name_in_program : Mapping [tuple [Hashable , ...], str ]
754759 output_tags : frozenset [Tag ]
755760 output_axes : tuple [pt .Axis , ...]
756761 output_name : str
757762
758763 def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
759- fn_name = self .pytato_program .entrypoint
760-
761764 input_kwargs_for_loopy = _args_to_device_buffers (
762- self .actx , self .input_id_to_name_in_program , arg_id_to_arg , fn_name )
765+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
763766
764767 _evt , out_dict = self .pytato_program (** input_kwargs_for_loopy )
765768
0 commit comments