-
Notifications
You must be signed in to change notification settings - Fork 667
Accept function parameters and closure cells as arguments with compile=True #6388
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| # limitations under the License. | ||
|
|
||
| import ast | ||
| import functools | ||
| import inspect | ||
| import itertools | ||
| import linecache | ||
|
|
@@ -36,7 +37,7 @@ | |
| ) | ||
| from nvidia.dali.types import DALIDataType, DALIImageType, DALIInterpType | ||
|
|
||
| from ._call_site import CodeLoc | ||
| from ._call_site import CodeLoc, resolve_callsite_frame | ||
| from ._compile import CompiledBatch, CompileRef | ||
| from ._device import Device | ||
| from ._type import DType | ||
|
|
@@ -168,6 +169,12 @@ def _unpack_bindings_impl(lhs: Sequence[cst.BaseElement], rhs: Sequence[cst.Base | |
| return bindings | ||
|
|
||
|
|
||
| @dataclass(frozen=True, slots=True) | ||
| class Binding: | ||
| in_scope: bool | ||
| rhs: cst.BaseExpression | None # None if binding is parameter | ||
|
|
||
|
|
||
| @dataclass(frozen=True, slots=True) | ||
| class ModuleInfo: | ||
| """Per-file parsed libcst data plus the queries classification needs over it.""" | ||
|
|
@@ -202,31 +209,33 @@ def _resolve_call(self, frame: types.FrameType) -> cst.Call | None: | |
| candidates = self.calls_by_line.get(frame.f_lineno, ()) | ||
| return candidates[0] if len(candidates) == 1 else None | ||
|
|
||
| def local_rhs(self, name_node: cst.Name) -> cst.BaseExpression | None: | ||
| """The RHS of `name_node`'s single capturable function-local binding, else None.""" | ||
| def binding(self, name_node: cst.Name) -> Binding | None: | ||
| """Return the single function-scope binding of `name_node`, else None.""" | ||
| scope = self.scope_of_node.get(name_node) | ||
| if scope is None: | ||
| return None | ||
| resolved = scope[name_node.value] # LEGB-resolved; empty set if undefined | ||
| resolved = scope[name_node.value] # LEGB-resolved | ||
| if len(resolved) != 1: | ||
| return None # rebound / nonlocal-rebind / undefined | ||
| return None # rebound / nonlocal | ||
|
|
||
| assignment = next(iter(resolved)) | ||
| if type(assignment) is not Assignment: | ||
| return None # excludes ImportAssignment and BuiltinAssignment | ||
| return None | ||
|
|
||
| if not isinstance(assignment.scope, FunctionScope): | ||
| return None # function locals only (excludes global / class / comprehension) | ||
| return None | ||
|
|
||
| in_scope = assignment.scope is scope | ||
| if isinstance(assignment.node, cst.Param): | ||
| return None # parameters, not handled for now | ||
| if assignment.scope is not scope: | ||
| return None # closure, not handled for now | ||
| return self._rhs_for_target(assignment.node) | ||
| return Binding(in_scope, None) | ||
|
|
||
| rhs = self._rhs_for_target(assignment.node) | ||
| return Binding(in_scope, rhs) if rhs is not None else None | ||
|
|
||
| def _rhs_for_target(self, target: cst.CSTNode) -> cst.BaseExpression | None: | ||
| match self.parent_of.get(target): | ||
| case cst.AssignTarget(target=cst.Name()) as target: # `x = v` | ||
| return cast(cst.Assign, self.parent_of.get(target)).value | ||
| case cst.AssignTarget(target=cst.Name()) as assign_target: # `x = v` | ||
| return cast(cst.Assign, self.parent_of.get(assign_target)).value | ||
| case cst.AnnAssign(value=value): # `x: T = v` | ||
| return value | ||
| case cst.NamedExpr(value=value): # walrus `(x := v)` | ||
|
|
@@ -309,10 +318,13 @@ def classify( | |
| self, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any] | ||
| ) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None: | ||
| call = self.module_info.call_at(self.frame) | ||
| if call is None or any(a.star for a in call.args): | ||
| return None # no call node, or caller-side *args/**kwargs | ||
| pos_nodes = [a.value for a in call.args if a.keyword is None] | ||
| kw_nodes = {a.keyword.value: a.value for a in call.args if a.keyword is not None} | ||
| if call is None: | ||
| return None | ||
|
|
||
| split = _split_call_args(call) | ||
| if split is None: | ||
| return None | ||
| pos_nodes, kw_nodes = split | ||
|
|
||
| try: | ||
| classified_inputs: list[CompileRef | Any] = [] | ||
|
|
@@ -340,9 +352,7 @@ def _capture_arg(self, node: cst.BaseExpression | None, value: Any) -> CompileRe | |
|
|
||
| def is_invariant(self, node: cst.BaseExpression) -> bool: | ||
| match node: | ||
| case cst.BaseNumber() | cst.SimpleString(): | ||
| return True | ||
| case cst.Name(value="True" | "False" | "None"): | ||
| case cst.BaseNumber() | cst.SimpleString() | cst.Name(value="True" | "False" | "None"): | ||
| return True | ||
| case cst.UnaryOperation(operator=cst.Minus() | cst.Plus(), expression=x): | ||
| return self.is_invariant(x) | ||
|
|
@@ -361,8 +371,8 @@ def is_invariant(self, node: cst.BaseExpression) -> bool: | |
| return False | ||
|
|
||
| def _is_name_invariant(self, name_node: cst.Name) -> bool: | ||
| rhs = self.module_info.local_rhs(name_node) | ||
| if rhs is None or not self.is_invariant(rhs): | ||
| binding = self.module_info.binding(name_node) | ||
| if binding is None or not self._is_binding_invariant(binding, name_node): | ||
| return False | ||
| # A named mutable is a live handle the user can alias and mutate. | ||
| # It's hard to prove that they are invariant. | ||
|
|
@@ -371,6 +381,81 @@ def _is_name_invariant(self, name_node: cst.Name) -> bool: | |
| except _Unresolved: | ||
| return False | ||
|
|
||
| def _is_binding_invariant(self, binding: Binding, name_node: cst.Name) -> bool: | ||
| """True if `name_node`'s binding is invariant (captured name re-roots at live owner).""" | ||
| if binding.in_scope: | ||
| classifier, frame = self, self.frame | ||
| elif frame := self._live_owner_frame(name_node.value): | ||
| classifier = _Classifier(self.module_info, frame) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Minor / Design note]
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What would you suggest?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe just: The contrast with _is_param_invariant is that call arguments can come from a different caller module, while a live |
||
| else: | ||
| return True # owner returned: frozen cell | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Minor] The comment
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it's necessary to re-specify this for this specific case. Immutability is checked by the caller for all names, not just cells of closures with dead parents. |
||
|
|
||
| if binding.rhs is None: | ||
| return classifier._is_param_invariant(name_node, frame) | ||
| return classifier.is_invariant(binding.rhs) | ||
|
|
||
| def _live_owner_frame(self, name: str) -> types.FrameType | None: | ||
| """Find the live frame owning a closure cell""" | ||
| frame = self.frame.f_back | ||
| while frame is not None: | ||
| if name in frame.f_code.co_cellvars: | ||
| return frame | ||
| frame = frame.f_back | ||
| return None | ||
|
|
||
| def _is_param_invariant(self, name_node: cst.Name, owner_frame: types.FrameType) -> bool: | ||
| """True if parameter `name_node` of `owner_frame` was passed an invariant argument.""" | ||
| caller = resolve_callsite_frame(owner_frame.f_back) | ||
| if caller is None: | ||
| return False | ||
|
rostan-t marked this conversation as resolved.
|
||
|
|
||
| mi = _get_module_info(caller.f_code.co_filename) # caller may be in another module | ||
| if mi is None: | ||
| return False | ||
|
|
||
| call = mi.call_at(caller) | ||
| if call is None: | ||
| return False | ||
|
|
||
| return _Classifier(mi, caller)._is_arg_invariant(call, name_node.value, owner_frame.f_code) | ||
|
|
||
| def _is_arg_invariant( | ||
| self, call: cst.Call, param_name: str, callee_code: types.CodeType | ||
| ) -> bool: | ||
| """True if `call` binds `param_name` of `callee_code` to an invariant argument.""" | ||
| split = _split_call_args(call) | ||
| if split is None: | ||
| return False | ||
| pos_nodes, kw_nodes = split | ||
|
|
||
| try: | ||
| callable_obj = _safe_resolve(call.func, self.frame) | ||
| except _Unresolved: | ||
| return False | ||
| if not _matches_callee(callable_obj, callee_code): | ||
| return False | ||
|
|
||
| try: | ||
| sig = inspect.signature(callable_obj, follow_wrapped=False) | ||
| except (ValueError, TypeError): | ||
| return False | ||
|
|
||
| param = sig.parameters.get(param_name) | ||
| if param is None or param.kind in ( | ||
| inspect.Parameter.VAR_POSITIONAL, | ||
| inspect.Parameter.VAR_KEYWORD, | ||
| ): | ||
| return False | ||
|
|
||
| try: | ||
| bound = sig.bind(*pos_nodes, **kw_nodes) | ||
| except TypeError: | ||
| return False | ||
|
|
||
| if param_name not in bound.arguments: | ||
| return param.default is not inspect.Parameter.empty # omitted: frozen default | ||
| return self.is_invariant(bound.arguments[param_name]) | ||
|
|
||
| def _is_dali_chain(self, node: cst.Attribute) -> bool: | ||
| """The only supported exceptions for attributes are those | ||
| anchored in nvidia.dali or a DALI enum. | ||
|
|
@@ -395,6 +480,26 @@ def _is_dali_chain(self, node: cst.Attribute) -> bool: | |
| return anchored and is_dali_constant(value) | ||
|
|
||
|
|
||
| def _split_call_args( | ||
| call: cst.Call, | ||
| ) -> tuple[list[cst.BaseExpression], dict[str, cst.BaseExpression]] | None: | ||
| """Extract positional and keyword arguments from a call expression.""" | ||
| if any(a.star for a in call.args): | ||
| return None | ||
| pos = [a.value for a in call.args if a.keyword is None] | ||
| kw = {a.keyword.value: a.value for a in call.args if a.keyword is not None} | ||
| return pos, kw | ||
|
|
||
|
|
||
| def _matches_callee(obj: Any, callee_code: types.CodeType) -> bool: | ||
| """Check that `obj` actually matches the function we're expecting to be in""" | ||
| if isinstance(obj, types.MethodType): | ||
| return _matches_callee(obj.__func__, callee_code) | ||
| if isinstance(obj, functools.partial): | ||
| return _matches_callee(obj.func, callee_code) | ||
| return isinstance(obj, types.FunctionType) and obj.__code__ is callee_code | ||
|
|
||
|
|
||
| def classify( | ||
| frame: types.FrameType, inputs: tuple[Any, ...], raw_kwargs: dict[str, Any] | ||
| ) -> tuple[list[CompileRef | Any], dict[str, CompileRef | Any]] | None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.