diff --git a/packages/reflex-base/src/reflex_base/components/memo.py b/packages/reflex-base/src/reflex_base/components/memo.py index 979d1501bd5..477c7431fe9 100644 --- a/packages/reflex-base/src/reflex_base/components/memo.py +++ b/packages/reflex-base/src/reflex_base/components/memo.py @@ -284,6 +284,9 @@ class MemoComponentDefinition(MemoDefinition): export_name: str _component: _LazyBody[Component] + _runtime_param_values: dict[str, Any] = dataclasses.field( + default_factory=dict, repr=False, compare=False + ) # For passthrough wrappers built by the auto-memoize plugin: the # ``Bare``-wrapped ``{children}`` placeholder used when rendering the memo # body. The ``component`` keeps its ORIGINAL children so compile-time @@ -756,16 +759,26 @@ def _rest_placeholder(name: str) -> RestProp: return RestProp(_js_expr=name, _var_type=dict[str, Any]) -def _var_placeholder(name: str, annotation: Any) -> Var: +def _var_placeholder( + name: str, + annotation: Any, + runtime_value: Any | None = None, +) -> Var: """Create a placeholder Var for a memo parameter. Args: name: The JavaScript identifier. annotation: The parameter annotation. + runtime_value: Optional runtime value used to infer unannotated params. Returns: The placeholder Var. """ + if _annotation_inner_type(annotation) is Any and runtime_value is not None: + runtime_type = ( + runtime_value._var_type if isinstance(runtime_value, Var) else type(runtime_value) + ) + return Var(_js_expr=name, _var_type=runtime_type).guess_type() return Var(_js_expr=name, _var_type=_annotation_inner_type(annotation)).guess_type() @@ -1033,12 +1046,14 @@ def finalize( def _evaluate_memo_function( fn: Callable[..., Any], params: tuple[MemoParam, ...], + runtime_values: Mapping[str, Any] | None = None, ) -> Any: """Evaluate a memo function with placeholder vars. Args: fn: The function to evaluate. params: The memo parameters. + runtime_values: Optional runtime values keyed by parameter name. Returns: The return value from the function. @@ -1047,7 +1062,14 @@ def _evaluate_memo_function( keyword_args = {} for param in params: - placeholder = param.make_placeholder() + if param.kind is MemoParamKind.VALUE: + placeholder = _var_placeholder( + param.placeholder_name, + param.annotation, + runtime_values.get(param.name) if runtime_values is not None else None, + ) + else: + placeholder = param.make_placeholder() if param.parameter_kind in ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, @@ -1299,13 +1321,16 @@ def _build_args_function( def _evaluate_component_body( - fn: Callable[..., Any], params: tuple[MemoParam, ...] + fn: Callable[..., Any], + params: tuple[MemoParam, ...], + runtime_values: Mapping[str, Any] | None = None, ) -> Component: """Run a component memo's body and return its compiled component. Args: fn: The decorated function. params: The analyzed memo parameters. + runtime_values: Optional runtime values keyed by parameter name. Returns: The wrapped component the body returned. @@ -1313,7 +1338,9 @@ def _evaluate_component_body( Raises: TypeError: If the body does not return a component. """ - body = _normalize_component_return(_evaluate_memo_function(fn, params)) + body = _normalize_component_return( + _evaluate_memo_function(fn, params, runtime_values) + ) if body is None: msg = ( f"Component-returning `@rx.memo` `{fn.__name__}` must return an " @@ -1359,13 +1386,17 @@ def _create_component_definition( TypeError: If the function does not return a component. """ params = _analyze_params(fn, for_component=True) + runtime_param_values: dict[str, Any] = {} return MemoComponentDefinition( fn=fn, python_name=fn.__name__, params=params, source_module=source_module, export_name=format.to_title_case(fn.__name__), - _component=_LazyBody.ready(_evaluate_component_body(fn, params)), + _component=_LazyBody( + lambda: _evaluate_component_body(fn, params, runtime_param_values) + ), + _runtime_param_values=runtime_param_values, ) @@ -1628,9 +1659,15 @@ def __call__(self, *children: Any, **props: Any) -> MemoComponent: # Reading ``component`` materializes the deferred body, so ``type(...)`` # reflects the real wrapped class rather than the placeholder. + definition._runtime_param_values.clear() + definition._runtime_param_values.update(explicit_values) + try: + component_type = type(definition.component) + finally: + definition._runtime_param_values.clear() return _get_memo_component_class( definition.export_name, - type(definition.component), + component_type, definition.source_module, )._create( children=list(children), @@ -1925,6 +1962,7 @@ def memo(fn: Callable[..., Any]) -> _MemoComponentWrapper | _MemoFunctionWrapper # where the name resolves to ``wrapper`` (already bound by first use). definition: MemoComponentDefinition | MemoFunctionDefinition if is_component: + runtime_param_values: dict[str, Any] = {} definition = MemoComponentDefinition( fn=fn, python_name=fn.__name__, @@ -1932,9 +1970,10 @@ def memo(fn: Callable[..., Any]) -> _MemoComponentWrapper | _MemoFunctionWrapper source_module=source_module, export_name=format.to_title_case(fn.__name__), _component=_LazyBody( - lambda: _evaluate_component_body(fn, params), + lambda: _evaluate_component_body(fn, params, runtime_param_values), placeholder=Fragment.create(), ), + _runtime_param_values=runtime_param_values, ) wrapper = _create_component_wrapper(definition) else: diff --git a/tests/units/components/test_memo.py b/tests/units/components/test_memo.py index 9fa38e51801..625f024e9b5 100644 --- a/tests/units/components/test_memo.py +++ b/tests/units/components/test_memo.py @@ -534,6 +534,42 @@ def soft_missing(value) -> rx.Component: assert "`value`" in kwargs["reason"] +def test_memo_uses_first_call_value_type_for_missing_param_annotation(): + """Component memos should infer missing parameter types from the first call.""" + + @rx.memo + def user_card(user) -> rx.Component: + return rx.box( + rx.heading(user["name"]), + rx.text(user["email"]), + ) + + component = user_card( + user={"name": "Ada", "email": "ada@example.com"}, + ) + + assert isinstance(component, MemoComponent) + + +def test_memo_uses_var_runtime_value_type_for_missing_param_annotation(): + """Component memos should infer missing parameter types from runtime Vars.""" + + @rx.memo + def user_card(user) -> rx.Component: + return rx.box( + rx.heading(user["name"]), + rx.text(user["email"]), + ) + + component = user_card( + user=Var(_js_expr="user", _var_type=dict), + ) + + assert isinstance(component, MemoComponent) + assert isinstance(component.user, Var) + assert component.user._var_type is dict + + def test_memo_warns_on_missing_return_annotation(): """A missing return annotation should default to ``rx.Component`` with a warning.""" with patch.object(console, "deprecate") as mock_deprecate: