diff --git a/dali/python/nvidia/dali/experimental/dynamic/_compile.py b/dali/python/nvidia/dali/experimental/dynamic/_compile.py index 44144d0e080..26b43dbc821 100644 --- a/dali/python/nvidia/dali/experimental/dynamic/_compile.py +++ b/dali/python/nvidia/dali/experimental/dynamic/_compile.py @@ -345,7 +345,11 @@ def _matches(self, actual: Any, expected: Any) -> bool: ) if expected is None: return actual is None - return not isinstance(actual, Batch) and actual == expected + if isinstance(actual, Batch): + return False + + result = actual == expected + return result if isinstance(result, bool) else np.all(result).item() @_nvtx_range("Getting compiled result") def get_compiled_result( diff --git a/dali/python/nvidia/dali/experimental/dynamic/_source_analysis.py b/dali/python/nvidia/dali/experimental/dynamic/_source_analysis.py index a88fe39ecb3..305b9f54d03 100644 --- a/dali/python/nvidia/dali/experimental/dynamic/_source_analysis.py +++ b/dali/python/nvidia/dali/experimental/dynamic/_source_analysis.py @@ -13,6 +13,7 @@ # limitations under the License. import ast +import functools import inspect import itertools import linecache @@ -36,7 +37,7 @@ ) from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType -from ._call_site import CodeLoc +from ._call_site import CodeLoc, resolve_callsite_frame from ._compile import CompiledBatch, CompileRef from ._device import Device from ._type import DType @@ -168,6 +169,12 @@ def _unpack_bindings_impl(lhs: Sequence[cst.BaseElement], rhs: Sequence[cst.Base return bindings +@dataclass(frozen=True, slots=True) +class Binding: + in_scope: bool + rhs: cst.BaseExpression | None # None if binding is parameter + + @dataclass(frozen=True, slots=True) class ModuleInfo: """Per-file parsed libcst data plus the queries classification needs over it.""" @@ -202,31 +209,33 @@ def _resolve_call(self, frame: types.FrameType) -> cst.Call | None: candidates = self.calls_by_line.get(frame.f_lineno, ()) return candidates[0] if len(candidates) == 1 else None - def local_rhs(self, name_node: cst.Name) -> cst.BaseExpression | None: - """The RHS of `name_node`'s single capturable function-local binding, else None.""" + def binding(self, name_node: cst.Name) -> Binding | None: + """Return the single function-scope binding of `name_node`, else None.""" scope = self.scope_of_node.get(name_node) if scope is None: return None - resolved = scope[name_node.value] # LEGB-resolved; empty set if undefined + resolved = scope[name_node.value] # LEGB-resolved if len(resolved) != 1: - return None # rebound / nonlocal-rebind / undefined + return None # rebound / nonlocal assignment = next(iter(resolved)) if type(assignment) is not Assignment: - return None # excludes ImportAssignment and BuiltinAssignment + return None if not isinstance(assignment.scope, FunctionScope): - return None # function locals only (excludes global / class / comprehension) + return None + + in_scope = assignment.scope is scope if isinstance(assignment.node, cst.Param): - return None # parameters, not handled for now - if assignment.scope is not scope: - return None # closure, not handled for now - return self._rhs_for_target(assignment.node) + return Binding(in_scope, None) + + rhs = self._rhs_for_target(assignment.node) + return Binding(in_scope, rhs) if rhs is not None else None def _rhs_for_target(self, target: cst.CSTNode) -> cst.BaseExpression | None: match self.parent_of.get(target): - case cst.AssignTarget(target=cst.Name()) as target: # `x = v` - return cast(cst.Assign, self.parent_of.get(target)).value + case cst.AssignTarget(target=cst.Name()) as assign_target: # `x = v` + return cast(cst.Assign, self.parent_of.get(assign_target)).value case cst.AnnAssign(value=value): # `x: T = v` return value case cst.NamedExpr(value=value): # walrus `(x := v)` @@ -309,10 +318,13 @@ def classify( self, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any] ) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None: call = self.module_info.call_at(self.frame) - if call is None or any(a.star for a in call.args): - return None # no call node, or caller-side *args/**kwargs - pos_nodes = [a.value for a in call.args if a.keyword is None] - kw_nodes = {a.keyword.value: a.value for a in call.args if a.keyword is not None} + if call is None: + return None + + split = _split_call_args(call) + if split is None: + return None + pos_nodes, kw_nodes = split try: classified_inputs: list[CompileRef | Any] = [] @@ -340,9 +352,7 @@ def _capture_arg(self, node: cst.BaseExpression | None, value: Any) -> CompileRe def is_invariant(self, node: cst.BaseExpression) -> bool: match node: - case cst.BaseNumber() | cst.SimpleString(): - return True - case cst.Name(value="True" | "False" | "None"): + case cst.BaseNumber() | cst.SimpleString() | cst.Name(value="True" | "False" | "None"): return True case cst.UnaryOperation(operator=cst.Minus() | cst.Plus(), expression=x): return self.is_invariant(x) @@ -361,8 +371,8 @@ def is_invariant(self, node: cst.BaseExpression) -> bool: return False def _is_name_invariant(self, name_node: cst.Name) -> bool: - rhs = self.module_info.local_rhs(name_node) - if rhs is None or not self.is_invariant(rhs): + binding = self.module_info.binding(name_node) + if binding is None or not self._is_binding_invariant(binding, name_node): return False # A named mutable is a live handle the user can alias and mutate. # It's hard to prove that they are invariant. @@ -371,6 +381,81 @@ def _is_name_invariant(self, name_node: cst.Name) -> bool: except _Unresolved: return False + def _is_binding_invariant(self, binding: Binding, name_node: cst.Name) -> bool: + """True if `name_node`'s binding is invariant (captured name re-roots at live owner).""" + if binding.in_scope: + classifier, frame = self, self.frame + elif frame := self._live_owner_frame(name_node.value): + classifier = _Classifier(self.module_info, frame) + else: + return True # owner returned: frozen cell + + if binding.rhs is None: + return classifier._is_param_invariant(name_node, frame) + return classifier.is_invariant(binding.rhs) + + def _live_owner_frame(self, name: str) -> types.FrameType | None: + """Find the live frame owning a closure cell""" + frame = self.frame.f_back + while frame is not None: + if name in frame.f_code.co_cellvars: + return frame + frame = frame.f_back + return None + + def _is_param_invariant(self, name_node: cst.Name, owner_frame: types.FrameType) -> bool: + """True if parameter `name_node` of `owner_frame` was passed an invariant argument.""" + caller = resolve_callsite_frame(owner_frame.f_back) + if caller is None: + return False + + mi = _get_module_info(caller.f_code.co_filename) # caller may be in another module + if mi is None: + return False + + call = mi.call_at(caller) + if call is None: + return False + + return _Classifier(mi, caller)._is_arg_invariant(call, name_node.value, owner_frame.f_code) + + def _is_arg_invariant( + self, call: cst.Call, param_name: str, callee_code: types.CodeType + ) -> bool: + """True if `call` binds `param_name` of `callee_code` to an invariant argument.""" + split = _split_call_args(call) + if split is None: + return False + pos_nodes, kw_nodes = split + + try: + callable_obj = _safe_resolve(call.func, self.frame) + except _Unresolved: + return False + if not _matches_callee(callable_obj, callee_code): + return False + + try: + sig = inspect.signature(callable_obj, follow_wrapped=False) + except (ValueError, TypeError): + return False + + param = sig.parameters.get(param_name) + if param is None or param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + return False + + try: + bound = sig.bind(*pos_nodes, **kw_nodes) + except TypeError: + return False + + if param_name not in bound.arguments: + return param.default is not inspect.Parameter.empty # omitted: frozen default + return self.is_invariant(bound.arguments[param_name]) + def _is_dali_chain(self, node: cst.Attribute) -> bool: """The only supported exceptions for attributes are those anchored in nvidia.dali or a DALI enum. @@ -395,6 +480,26 @@ def _is_dali_chain(self, node: cst.Attribute) -> bool: return anchored and is_dali_constant(value) +def _split_call_args( + call: cst.Call, +) -> tuple[list[cst.BaseExpression], dict[str, cst.BaseExpression]] | None: + """Extract positional and keyword arguments from a call expression.""" + if any(a.star for a in call.args): + return None + pos = [a.value for a in call.args if a.keyword is None] + kw = {a.keyword.value: a.value for a in call.args if a.keyword is not None} + return pos, kw + + +def _matches_callee(obj: Any, callee_code: types.CodeType) -> bool: + """Check that `obj` actually matches the function we're expecting to be in""" + if isinstance(obj, types.MethodType): + return _matches_callee(obj.__func__, callee_code) + if isinstance(obj, functools.partial): + return _matches_callee(obj.func, callee_code) + return isinstance(obj, types.FunctionType) and obj.__code__ is callee_code + + def classify( frame: types.FrameType, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any] ) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None: diff --git a/dali/test/python/experimental_mode/test_compile_invariants.py b/dali/test/python/experimental_mode/test_compile_invariants.py index d9cb8cbae90..5cb6769f1b8 100644 --- a/dali/test/python/experimental_mode/test_compile_invariants.py +++ b/dali/test/python/experimental_mode/test_compile_invariants.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib.util +import pathlib +import tempfile import functools import os from collections.abc import Callable @@ -63,8 +66,6 @@ def test(): return decorator -# Module-level fixtures for the rejection tests -# _MODULE_DTYPE = ndd.float32 _GLOBAL_ANGLE = 60 @@ -223,6 +224,171 @@ def transform(images): compiled_test(expect_captured=True)(transform)() +@compiled_test(expect_captured=True) +def test_param_literal(images): + def rotate(imgs, angle): + return ndd.rotate(imgs, angle=angle) + + return rotate(images, 60) + + +@compiled_test(expect_captured=True) +def test_param_expression(images): + def rotate(imgs, angle, /): + return ndd.rotate(imgs, angle=angle) + + return rotate(images, 40 + 20) + + +@compiled_test(expect_captured=True) +def test_param_dali_attribute(images): + def cast(imgs, dtype): + return ndd.cast(imgs, dtype=dtype) + + return cast(images, ndd.int32) + + +@compiled_test(expect_captured=True) +def test_param_chained(images): + def inner(imgs, angle): + return ndd.rotate(imgs, angle=angle) + + def outer(imgs, angle): + return inner(imgs, angle) + + return outer(images, 60) + + +@compiled_test(expect_captured=True) +def test_param_recursive(images): + def _rotate_recursive(images, angle, depth): + if depth == 0: + return ndd.rotate(images, angle=angle) + return _rotate_recursive(images, angle, depth - 1) + + return _rotate_recursive(images, 60, 4) + + +@compiled_test(expect_captured=True) +def test_param_cross_file(images): + source = """ +import nvidia.dali.experimental.dynamic as ndd + +def resize(*, images, size): + return ndd.resize(images, size=size) +""" + with tempfile.TemporaryDirectory() as tmpdir: + module_path = pathlib.Path(tmpdir) / "module.py" + module_path.write_text(source) + + # Import the module + spec = importlib.util.spec_from_file_location("module", module_path) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + return module.resize(images=images, size=(224, 224)) + + +@compiled_test(expect_captured=True) +def test_param_method(images): + class Aug: + def rotate(self, imgs, angle): + return ndd.rotate(imgs, angle=angle) + + aug = Aug() + return aug.rotate(images, 60) + + +@compiled_test(expect_captured=True) +def test_param_classmethod(images): + class Aug: + @classmethod + def rotate(cls, imgs, angle): + return ndd.rotate(imgs, angle=angle) + + return Aug.rotate(images, 60) + + +@compiled_test(expect_captured=True) +def test_param_staticmethod(images): + class Aug: + @staticmethod + def rotate(imgs, angle): + return ndd.rotate(imgs, angle=angle) + + return Aug.rotate(images, 60) + + +@compiled_test(expect_captured=True) +def test_param_partial_keyword(images): + def resize(imgs, width, height): + return ndd.resize(imgs, size=[width, height]) + + resize_partial = functools.partial(resize, width=64) + return resize_partial(images, height=128) + + +@compiled_test(expect_captured=True) +def test_closure_param(images): + def make_rotate(angle): + def rotate(): + return ndd.rotate(images, angle=angle) + + return rotate + + return make_rotate(60)() + + +@compiled_test(expect_captured=True) +def test_closure_local(images): + def make_rotate(): + def rotate(): + return ndd.rotate(images, angle=angle) + + angle = 60 + return rotate + + return make_rotate()() + + +@compiled_test(expect_captured=True) +def test_closure_live_parent(images): + def rotate(angle): + def transform(): + return ndd.rotate(images, angle=angle) + + return transform() + + return rotate(60) + + +@compiled_test(expect_captured=True) +def test_param_default_literal(images): + def rotate(imgs, angle=60): + return ndd.rotate(imgs, angle=angle) + + return rotate(images) + + +@compiled_test(expect_captured=True) +def test_param_default_dali_attribute(images): + def cast(imgs, dtype=ndd.int32): + return ndd.cast(imgs, dtype=dtype) + + return cast(images) + + +@compiled_test(expect_captured=True) +def test_param_default_name(images): + angle = 60 + + def rotate(imgs, angle=angle): + return ndd.rotate(imgs, angle=angle) + + return rotate(images) + + # Tests for rejected cases @@ -313,3 +479,100 @@ def test_import_name(images): from math import pi return ndd.rotate(images, angle=pi) + + +@compiled_test(expect_captured=False) +def test_param_mutable(images): + def resize(imgs, size): + size[1] = 42 + return ndd.resize(imgs, size=size) + + return resize(images, [224, 224]) + + +@compiled_test(expect_captured=False) +def test_param_varargs(images): + def resize(imgs, *size): + return ndd.resize(imgs, size=size) + + return resize(images, 224, 224) + + +@compiled_test(expect_captured=False) +def test_default_mutable(images): + def resize(imgs, size=[224, 224]): # noqa: B006 + return ndd.resize(imgs, size=size) + + return resize(images) + + +@compiled_test(expect_captured=False) +def test_closure_mutable_cell(images): + def make_resize(size): + def resize(): + return ndd.resize(images, size=size) + + return resize + + return make_resize([224, 224])() + + +@compiled_test(expect_captured=False) +def test_closure_nonlocal_rebind(images): + def make_rotate(): + angle = 60 + + def rebind(): + nonlocal angle + angle = 90 + + def rotate(): + return ndd.rotate(images, angle=angle) + + rebind() + return rotate + + return make_rotate()() + + +@compiled_test(expect_captured=False) +def test_param_through_decorator(images): + def rotate(imgs, angle): + return ndd.rotate(imgs, angle=angle) + + @functools.wraps(rotate) + def wrapped(*args, **kwargs): + return rotate(*args, **kwargs) + + return wrapped(images, 60) + + +@compiled_test(expect_captured=False) +def test_param_inline_call_target(images): + def make_rotate(): + def rotate(imgs, angle): + return ndd.rotate(imgs, angle=angle) + + return rotate + + return make_rotate()(images, 60) + + +@compiled_test(expect_captured=False) +def test_param_callable_instance_attr(images): + class Aug: + def __call__(self, imgs, angle): + return ndd.rotate(imgs, angle=angle) + + class Holder: + aug = Aug() + + return Holder.aug(images, 60) + + +@compiled_test(expect_captured=False) +def test_param_global_arg(images): + def augment(imgs, crop): + return ndd.rotate(imgs, angle=crop) + + return augment(images, _GLOBAL_ANGLE)