Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dali/python/nvidia/dali/experimental/dynamic/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,11 @@ def _matches(self, actual: Any, expected: Any) -> bool:
)
if expected is None:
return actual is None
return not isinstance(actual, Batch) and actual == expected
if isinstance(actual, Batch):
return False

result = actual == expected
return result if isinstance(result, bool) else np.all(result).item()

@_nvtx_range("Getting compiled result")
def get_compiled_result(
Expand Down
149 changes: 127 additions & 22 deletions dali/python/nvidia/dali/experimental/dynamic/_source_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import ast
import functools
import inspect
import itertools
import linecache
Expand All @@ -36,7 +37,7 @@
)
from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType

from ._call_site import CodeLoc
from ._call_site import CodeLoc, resolve_callsite_frame
from ._compile import CompiledBatch, CompileRef
from ._device import Device
from ._type import DType
Expand Down Expand Up @@ -168,6 +169,12 @@ def _unpack_bindings_impl(lhs: Sequence[cst.BaseElement], rhs: Sequence[cst.Base
return bindings


@dataclass(frozen=True, slots=True)
class Binding:
in_scope: bool
rhs: cst.BaseExpression | None # None if binding is parameter


@dataclass(frozen=True, slots=True)
class ModuleInfo:
"""Per-file parsed libcst data plus the queries classification needs over it."""
Expand Down Expand Up @@ -202,31 +209,33 @@ def _resolve_call(self, frame: types.FrameType) -> cst.Call | None:
candidates = self.calls_by_line.get(frame.f_lineno, ())
return candidates[0] if len(candidates) == 1 else None

def local_rhs(self, name_node: cst.Name) -> cst.BaseExpression | None:
"""The RHS of `name_node`'s single capturable function-local binding, else None."""
def binding(self, name_node: cst.Name) -> Binding | None:
"""Return the single function-scope binding of `name_node`, else None."""
scope = self.scope_of_node.get(name_node)
if scope is None:
return None
resolved = scope[name_node.value] # LEGB-resolved; empty set if undefined
resolved = scope[name_node.value] # LEGB-resolved
if len(resolved) != 1:
return None # rebound / nonlocal-rebind / undefined
return None # rebound / nonlocal

assignment = next(iter(resolved))
if type(assignment) is not Assignment:
return None # excludes ImportAssignment and BuiltinAssignment
return None

if not isinstance(assignment.scope, FunctionScope):
return None # function locals only (excludes global / class / comprehension)
return None

in_scope = assignment.scope is scope
if isinstance(assignment.node, cst.Param):
return None # parameters, not handled for now
if assignment.scope is not scope:
return None # closure, not handled for now
return self._rhs_for_target(assignment.node)
return Binding(in_scope, None)

rhs = self._rhs_for_target(assignment.node)
return Binding(in_scope, rhs) if rhs is not None else None

def _rhs_for_target(self, target: cst.CSTNode) -> cst.BaseExpression | None:
match self.parent_of.get(target):
case cst.AssignTarget(target=cst.Name()) as target: # `x = v`
return cast(cst.Assign, self.parent_of.get(target)).value
case cst.AssignTarget(target=cst.Name()) as assign_target: # `x = v`
return cast(cst.Assign, self.parent_of.get(assign_target)).value
case cst.AnnAssign(value=value): # `x: T = v`
return value
case cst.NamedExpr(value=value): # walrus `(x := v)`
Expand Down Expand Up @@ -309,10 +318,13 @@ def classify(
self, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any]
) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None:
call = self.module_info.call_at(self.frame)
if call is None or any(a.star for a in call.args):
return None # no call node, or caller-side *args/**kwargs
pos_nodes = [a.value for a in call.args if a.keyword is None]
kw_nodes = {a.keyword.value: a.value for a in call.args if a.keyword is not None}
if call is None:
return None

split = _split_call_args(call)
if split is None:
return None
pos_nodes, kw_nodes = split

try:
classified_inputs: list[CompileRef | Any] = []
Expand Down Expand Up @@ -340,9 +352,7 @@ def _capture_arg(self, node: cst.BaseExpression | None, value: Any) -> CompileRe

def is_invariant(self, node: cst.BaseExpression) -> bool:
match node:
case cst.BaseNumber() | cst.SimpleString():
return True
case cst.Name(value="True" | "False" | "None"):
case cst.BaseNumber() | cst.SimpleString() | cst.Name(value="True" | "False" | "None"):
Comment thread
rostan-t marked this conversation as resolved.
Dismissed
return True
case cst.UnaryOperation(operator=cst.Minus() | cst.Plus(), expression=x):
return self.is_invariant(x)
Expand All @@ -361,8 +371,8 @@ def is_invariant(self, node: cst.BaseExpression) -> bool:
return False

def _is_name_invariant(self, name_node: cst.Name) -> bool:
rhs = self.module_info.local_rhs(name_node)
if rhs is None or not self.is_invariant(rhs):
binding = self.module_info.binding(name_node)
if binding is None or not self._is_binding_invariant(binding, name_node):
return False
# A named mutable is a live handle the user can alias and mutate.
# It's hard to prove that they are invariant.
Expand All @@ -371,6 +381,81 @@ def _is_name_invariant(self, name_node: cst.Name) -> bool:
except _Unresolved:
return False

def _is_binding_invariant(self, binding: Binding, name_node: cst.Name) -> bool:
"""True if `name_node`'s binding is invariant (captured name re-roots at live owner)."""
if binding.in_scope:
classifier, frame = self, self.frame
elif frame := self._live_owner_frame(name_node.value):
classifier = _Classifier(self.module_info, frame)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Minor / Design note] self.module_info is reused here as the CST analysis for the owner's frame. This is correct because Python closures are always lexically nested in the same source file, so the owner frame's variable bindings are covered by the same ModuleInfo. Worth a brief comment clarifying this assumption for readers who might wonder why we don't do _get_module_info(frame.f_code.co_filename) here like _is_param_invariant does.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would you suggest?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just:

          elif frame := self._live_owner_frame(name_node.value):
              # A live closure owner is a lexical parent, so it shares this module's CST metadata.
              classifier = _Classifier(self.module_info, frame)

The contrast with _is_param_invariant is that call arguments can come from a different caller module, while a live
closure owner is lexically enclosing this function, so self.module_info is the right one.

else:
return True # owner returned: frozen cell

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Minor] The comment # owner returned: frozen cell is accurate but doesn't tell the full story: returning True here only says the binding is stable (the cell can't be rebound once the owner has returned); the immutability of the value is still checked later in _is_name_invariant via _is_immutable_value(_static_eval(...)). This two-phase separation is the right design, but a one-liner like # cell is frozen; value immutability is checked by the caller would make the invariant clearer.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's necessary to re-specify this for this specific case. Immutability is checked by the caller for all names, not just cells of closures with dead parents.


if binding.rhs is None:
return classifier._is_param_invariant(name_node, frame)
return classifier.is_invariant(binding.rhs)

def _live_owner_frame(self, name: str) -> types.FrameType | None:
"""Find the live frame owning a closure cell"""
frame = self.frame.f_back
while frame is not None:
if name in frame.f_code.co_cellvars:
return frame
frame = frame.f_back
return None

def _is_param_invariant(self, name_node: cst.Name, owner_frame: types.FrameType) -> bool:
"""True if parameter `name_node` of `owner_frame` was passed an invariant argument."""
caller = resolve_callsite_frame(owner_frame.f_back)
if caller is None:
return False
Comment thread
rostan-t marked this conversation as resolved.

mi = _get_module_info(caller.f_code.co_filename) # caller may be in another module
if mi is None:
return False

call = mi.call_at(caller)
if call is None:
return False

return _Classifier(mi, caller)._is_arg_invariant(call, name_node.value, owner_frame.f_code)

def _is_arg_invariant(
    self, call: cst.Call, param_name: str, callee_code: types.CodeType
) -> bool:
    """Tell whether `call` passes an invariant argument for `param_name`.

    The answer is False whenever the mapping from call-site argument nodes to
    the callee's parameters cannot be established: the call uses `*`/`**`
    unpacking, the called expression cannot be resolved in `self.frame`, the
    resolved object is not backed by `callee_code`, its signature is
    unavailable, the parameter is missing or variadic, or the arguments do not
    bind to the signature. When the argument is omitted at the call site, the
    result is True only if the parameter declares a default.
    """
    call_args = _split_call_args(call)
    if call_args is None:
        return False  # caller-side */** unpacking: nodes can't be mapped to parameters

    try:
        target = _safe_resolve(call.func, self.frame)
    except _Unresolved:
        return False
    # Make sure the resolved callable really is the function owning the parameter.
    if not _matches_callee(target, callee_code):
        return False

    try:
        # follow_wrapped=False: inspect the object itself, not its `__wrapped__` chain.
        signature = inspect.signature(target, follow_wrapped=False)
    except (ValueError, TypeError):
        return False

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    parameter = signature.parameters.get(param_name)
    if parameter is None or parameter.kind in variadic_kinds:
        return False

    positional, keyword = call_args
    try:
        # Bind the CST nodes (not runtime values) to find the node feeding the parameter.
        arguments = signature.bind(*positional, **keyword).arguments
    except TypeError:
        return False

    if param_name in arguments:
        return self.is_invariant(arguments[param_name])
    # Argument omitted at the call site: invariant only when a default fills it in.
    return parameter.default is not inspect.Parameter.empty

def _is_dali_chain(self, node: cst.Attribute) -> bool:
"""The only supported exceptions for attributes are those
anchored in nvidia.dali or a DALI enum.
Expand All @@ -395,6 +480,26 @@ def _is_dali_chain(self, node: cst.Attribute) -> bool:
return anchored and is_dali_constant(value)


def _split_call_args(
call: cst.Call,
) -> tuple[list[cst.BaseExpression], dict[str, cst.BaseExpression]] | None:
"""Extract positional and keyword arguments from a call expression."""
if any(a.star for a in call.args):
return None
pos = [a.value for a in call.args if a.keyword is None]
kw = {a.keyword.value: a.value for a in call.args if a.keyword is not None}
return pos, kw


def _matches_callee(obj: Any, callee_code: types.CodeType) -> bool:
"""Check that `obj` actually matches the function we're expecting to be in"""
if isinstance(obj, types.MethodType):
return _matches_callee(obj.__func__, callee_code)
if isinstance(obj, functools.partial):
return _matches_callee(obj.func, callee_code)
return isinstance(obj, types.FunctionType) and obj.__code__ is callee_code


def classify(
frame: types.FrameType, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any]
) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None:
Expand Down
Loading