From 058beffc2711e1a364c9afd78b83b9055c5b63ca Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 5 Sep 2025 10:31:18 -0700 Subject: [PATCH 1/2] PropsBase converts EventHandler-annotated props to EventChain --- reflex/components/component.py | 14 ++--------- reflex/components/props.py | 15 ++++++++++++ reflex/event.py | 27 +++++++++++++++++++++ tests/units/components/test_props.py | 36 ++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 12 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 383ea75b560..9215cd79542 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -43,6 +43,7 @@ EventChain, EventHandler, EventSpec, + args_specs_from_fields, no_args_event_spec, parse_args_spec, pointer_event_spec, @@ -909,18 +910,7 @@ def get_event_triggers(cls) -> dict[str, types.ArgsSpec | Sequence[types.ArgsSpe """ # Look for component specific triggers, # e.g. variable declared as EventHandler types. - return DEFAULT_TRIGGERS | { - name: ( - metadata[0] - if ( - (metadata := getattr(field.annotated_type, "__metadata__", None)) - is not None - ) - else no_args_event_spec - ) - for name, field in cls.get_fields().items() - if field.type_origin is EventHandler - } # pyright: ignore [reportOperatorIssue] + return DEFAULT_TRIGGERS | args_specs_from_fields(cls.get_fields()) # pyright: ignore [reportOperatorIssue] def __repr__(self) -> str: """Represent the component in React. diff --git a/reflex/components/props.py b/reflex/components/props.py index 285edd50479..968092bbf97 100644 --- a/reflex/components/props.py +++ b/reflex/components/props.py @@ -9,6 +9,7 @@ from typing_extensions import dataclass_transform from reflex.components.field import BaseField, FieldBasedMeta +from reflex.event import EventChain, args_specs_from_fields from reflex.utils import format from reflex.utils.exceptions import InvalidPropValueError from reflex.utils.serializers import serializer @@ -267,6 +268,20 @@ def __init__(self, **kwargs): setattr(self, field_name, field.default_factory()) # Note: Fields with no default and no factory remain unset (required fields) + # Convert EventHandler to EventChain + args_specs = args_specs_from_fields(self.get_fields()) + for handler_name, args_spec in args_specs.items(): + if (handler := getattr(self, handler_name, None)) is not None: + setattr( + self, + handler_name, + EventChain.create( + value=handler, + args_spec=args_spec, + key=handler_name, + ), + ) + @classmethod def get_fields(cls) -> dict[str, Any]: """Get the fields of the object. diff --git a/reflex/event.py b/reflex/event.py index d2b37b024f3..d3d706f8037 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -25,6 +25,7 @@ from typing_extensions import Self, TypeAliasType, TypedDict, TypeVarTuple, Unpack from reflex import constants +from reflex.components.field import BaseField from reflex.constants.compiler import CompileVars, Hooks, Imports from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import format @@ -1654,6 +1655,31 @@ def parse_args_spec(arg_spec: ArgsSpec | Sequence[ArgsSpec]): ), annotations +def args_specs_from_fields( + fields_dict: Mapping[str, BaseField], +) -> dict[str, ArgsSpec | Sequence[ArgsSpec]]: + """Get the event triggers and arg specs from the given fields. + + Args: + fields_dict: The fields, keyed by name + + Returns: + The args spec for any field annotated as EventHandler. + """ + return { + name: ( + metadata[0] + if ( + (metadata := getattr(field.annotated_type, "__metadata__", None)) + is not None + ) + else no_args_event_spec + ) + for name, field in fields_dict.items() + if field.type_origin is EventHandler + } + + def check_fn_match_arg_spec( user_func: Callable, user_func_parameters: Mapping[str, inspect.Parameter], @@ -2406,6 +2432,7 @@ def wrapper( check_fn_match_arg_spec = staticmethod(check_fn_match_arg_spec) resolve_annotation = staticmethod(resolve_annotation) parse_args_spec = staticmethod(parse_args_spec) + args_specs_from_fields = staticmethod(args_specs_from_fields) unwrap_var_annotation = staticmethod(unwrap_var_annotation) get_fn_signature = staticmethod(get_fn_signature) diff --git a/tests/units/components/test_props.py b/tests/units/components/test_props.py index 81c37b388cc..f68fc364434 100644 --- a/tests/units/components/test_props.py +++ b/tests/units/components/test_props.py @@ -1,7 +1,17 @@ +from __future__ import annotations + import pytest from pydantic.v1 import ValidationError from reflex.components.props import NoExtrasAllowedProps, PropsBase +from reflex.event import ( + EventChain, + EventHandler, + event, + no_args_event_spec, + passthrough_event_spec, +) +from reflex.state import State from reflex.utils.exceptions import InvalidPropValueError @@ -177,3 +187,29 @@ def test_props_base_dict_conversion(props_class, props_kwargs, expected_dict): props = props_class(**props_kwargs) result = props.dict() assert result == expected_dict + + +class EventProps(PropsBase): + """Test props with event handler fields.""" + + on_click: EventHandler[no_args_event_spec] + not_start_with_on: EventHandler[passthrough_event_spec(str)] + + +def test_event_handler_props(): + class FooState(State): + @event + def handle_click(self): + pass + + @event + def handle_input(self, value: str): + pass + + props = EventProps( + on_click=FooState.handle_click, # pyright: ignore[reportArgumentType] + not_start_with_on=FooState.handle_input, # pyright: ignore[reportArgumentType] + ) + props_dict = props.dict() + assert isinstance(props_dict["onClick"], EventChain) + assert isinstance(props_dict["notStartWithOn"], EventChain) From 16ee80339a9f0706906474e12a078f1d06c0166a Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 11 Sep 2025 17:53:42 -0700 Subject: [PATCH 2/2] Move _resolve_annotations to FieldBasedMeta Allow both Component and PropsBase to resolve annotations from the module namespace. --- reflex/components/component.py | 8 -------- reflex/components/field.py | 6 +++++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index 993d63a931e..3ff9d177ef7 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -144,14 +144,6 @@ class BaseComponentMeta(FieldBasedMeta, ABCMeta): _fields: Mapping[str, ComponentField] _js_fields: Mapping[str, ComponentField] - @classmethod - def _resolve_annotations( - cls, namespace: dict[str, Any], name: str - ) -> dict[str, Any]: - return types.resolve_annotations( - namespace.get("__annotations__", {}), namespace["__module__"] - ) - @classmethod def _process_annotated_fields( cls, diff --git a/reflex/components/field.py b/reflex/components/field.py index f19dd6889fc..cfe06caf8c4 100644 --- a/reflex/components/field.py +++ b/reflex/components/field.py @@ -6,6 +6,8 @@ from dataclasses import _MISSING_TYPE, MISSING from typing import Annotated, Any, Generic, TypeVar, get_origin +from reflex.utils import types + FIELD_TYPE = TypeVar("FIELD_TYPE") @@ -114,7 +116,9 @@ def _collect_inherited_fields(cls, bases: tuple[type]) -> dict[str, Any]: def _resolve_annotations( cls, namespace: dict[str, Any], name: str ) -> dict[str, Any]: - return namespace.get("__annotations__", {}) + return types.resolve_annotations( + namespace.get("__annotations__", {}), namespace["__module__"] + ) @classmethod def _process_field_overrides(