From d0ce7ffb48ee89e187fcc155e37641831d67b294 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 28 Aug 2025 13:38:02 -0700 Subject: [PATCH 1/3] optimize various functions --- reflex/app.py | 20 ++++++-- reflex/compiler/templates.py | 2 +- reflex/compiler/utils.py | 11 +++-- reflex/components/component.py | 75 ++++++++++++++++------------- reflex/environment.py | 5 +- reflex/istate/data.py | 15 ++++++ reflex/plugins/sitemap.py | 9 ++-- reflex/reflex.py | 10 +++- reflex/state.py | 4 +- reflex/testing.py | 2 +- reflex/utils/console.py | 46 +++++++++++++++++- reflex/utils/prerequisites.py | 4 +- reflex/utils/pyi_generator.py | 4 +- reflex/utils/serializers.py | 8 +-- reflex/utils/types.py | 14 ++---- reflex/vars/base.py | 31 ++++++------ reflex/vars/object.py | 6 ++- tests/units/plugins/test_sitemap.py | 13 +++-- 18 files changed, 179 insertions(+), 100 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index dae1a80fd11..a298d1beef0 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1100,12 +1100,18 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No for substate in state.class_subclasses: self._validate_var_dependencies(substate) - def _compile(self, prerender_routes: bool = False, dry_run: bool = False): + def _compile( + self, + prerender_routes: bool = False, + dry_run: bool = False, + use_rich: bool = True, + ): """Compile the app and output it to the pages folder. Args: prerender_routes: Whether to prerender the routes. dry_run: Whether to compile the app without saving it. + use_rich: Whether to use rich progress bars. Raises: ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined. @@ -1171,10 +1177,14 @@ def get_compilation_time() -> str: return # Create a progress bar. - progress = Progress( - *Progress.get_default_columns()[:-1], - MofNCompleteColumn(), - TimeElapsedColumn(), + progress = ( + Progress( + *Progress.get_default_columns()[:-1], + MofNCompleteColumn(), + TimeElapsedColumn(), + ) + if use_rich + else console.PoorProgress() ) # try to be somewhat accurate - but still not 100% diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index ba286a8decc..9712f026b25 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -362,7 +362,7 @@ def context_template( export const DispatchContext = createContext(null); export const StateContexts = {{{state_contexts_str}}}; export const EventLoopContext = createContext(null); -export const clientStorage = {"{}" if client_storage is None else json_dumps(client_storage)} +export const clientStorage = {"{}" if client_storage is None else json.dumps(client_storage)} {state_str} diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index f85a5fe84c0..3a8a33a5a9c 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -220,10 +220,13 @@ def compile_state(state: type[BaseState]) -> dict: def _compile_client_storage_field( field: Field, -) -> tuple[ - type[Cookie] | type[LocalStorage] | type[SessionStorage] | None, - dict[str, Any] | None, -]: +) -> ( + tuple[ + type[Cookie] | type[LocalStorage] | type[SessionStorage], + dict[str, Any], + ] + | tuple[None, None] +): """Compile the given cookie, local_storage or session_storage field. Args: diff --git a/reflex/components/component.py b/reflex/components/component.py index 5db03404668..a448e0aa398 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -419,7 +419,6 @@ def evaluate_style_namespaces(style: ComponentStyle) -> dict: # Map from component to styling. ComponentStyle = dict[str | type[BaseComponent] | Callable | ComponentNamespace, Any] -ComponentChild = types.PrimitiveType | Var | BaseComponent ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent, type(None)) @@ -480,7 +479,21 @@ def _components_from( return () -def _deterministic_hash(value: object) -> int: +def _hash_str(value: str) -> str: + return md5(f'"{value}"'.encode(), usedforsecurity=False).hexdigest() + + +def _hash_sequence(value: Sequence) -> str: + return _hash_str(str([_deterministic_hash(v) for v in value])) + + +def _hash_dict(value: dict) -> str: + return _hash_sequence( + sorted([(k, _deterministic_hash(v)) for k, v in value.items()]) + ) + + +def _deterministic_hash(value: object) -> str: """Hash a rendered dictionary. Args: @@ -492,37 +505,28 @@ def _deterministic_hash(value: object) -> int: Raises: TypeError: If the value is not hashable. """ - if isinstance(value, BaseComponent): - # If the value is a component, hash its rendered code. - rendered_code = value.render() - return _deterministic_hash(rendered_code) - if isinstance(value, Var): - return _deterministic_hash((value._js_expr, value._get_all_var_data())) - if isinstance(value, VarData): - return _deterministic_hash(dataclasses.asdict(value)) - if isinstance(value, dict): - # Sort the dictionary to ensure consistent hashing. - return _deterministic_hash( - tuple(sorted((k, _deterministic_hash(v)) for k, v in value.items())) - ) - if isinstance(value, int): + if value is None: + # Hash None as a special case. + return "None" + if isinstance(value, (int, float, enum.Enum)): # Hash numbers and booleans directly. - return int(value) - if isinstance(value, float): - return _deterministic_hash(str(value)) + return str(value) if isinstance(value, str): - return int(md5(f'"{value}"'.encode()).hexdigest(), 16) + return _hash_str(value) + if isinstance(value, dict): + return _hash_dict(value) if isinstance(value, (tuple, list)): # Hash tuples by hashing each element. - return _deterministic_hash( - "[" + ",".join(map(str, map(_deterministic_hash, value))) + "]" + return _hash_sequence(value) + if isinstance(value, Var): + return _hash_str( + str((value._js_expr, _deterministic_hash(value._get_all_var_data()))) ) - if isinstance(value, enum.Enum): - # Hash enums by their name. - return _deterministic_hash(str(value)) - if value is None: - # Hash None as a special case. - return _deterministic_hash("None") + if isinstance(value, VarData): + return _hash_dict(dataclasses.asdict(value)) + if isinstance(value, BaseComponent): + # If the value is a component, hash its rendered code. + return _hash_dict(value.render()) msg = ( f"Cannot hash value `{value}` of type `{type(value).__name__}`. " @@ -1038,9 +1042,13 @@ def _get_component_prop_names(cls) -> set[str]: name for name in cls.get_fields() if name in cls.get_props() - and types._issubclass( - types.value_inside_optional(types.get_field_type(cls, name)), Component + and isinstance( + field_type := types.value_inside_optional( + types.get_field_type(cls, name) + ), + type, ) + and issubclass(field_type, Component) } def _get_components_in_props(self) -> Sequence[BaseComponent]: @@ -1508,8 +1516,9 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen Returns: A sequence of parent classes that define the method (differently than the base). """ + current_class_method = getattr(cls, method, None) seen_methods = ( - {getattr(Component, method)} if hasattr(Component, method) else set() + {current_class_method} if current_class_method is not None else set() ) clzs: list[type[Component]] = [] for clz in cls.mro(): @@ -1807,7 +1816,7 @@ def extract_var_hooks(hook: Var): # Add the hook code from add_hooks for each parent class (this is reversed to preserve # the order of the hooks in the final output) - for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))): + for clz in reversed(self._iter_parent_classes_with_method("add_hooks")): for hook in clz.add_hooks(self): if isinstance(hook, Var): extract_var_hooks(hook) @@ -2457,7 +2466,7 @@ def _get_tag_name(cls, component: Component) -> str | None: return None # Compute the hash based on the rendered code. - code_hash = _deterministic_hash(rendered_code) + code_hash = _hash_str(_deterministic_hash(rendered_code)) # Format the tag name including the hash. return format.format_state_name( diff --git a/reflex/environment.py b/reflex/environment.py index 5e5eb5ff873..890fe5743ed 100644 --- a/reflex/environment.py +++ b/reflex/environment.py @@ -6,7 +6,6 @@ import dataclasses import enum import importlib -import inspect import multiprocessing import os import platform @@ -159,7 +158,7 @@ def interpret_plugin_env(value: str, field_name: str) -> Plugin: msg = f"Failed to get plugin class {plugin_name!r} from module {import_path!r} for {field_name}: {e}" raise EnvironmentVarValueError(msg) from e - if not inspect.isclass(plugin_class) or not issubclass(plugin_class, Plugin): + if not isinstance(plugin_class, type) or not issubclass(plugin_class, Plugin): msg = f"Invalid plugin class: {plugin_name!r} for {field_name}. Must be a subclass of Plugin." raise EnvironmentVarValueError(msg) @@ -236,7 +235,7 @@ def interpret_env_var_value( ) for i, v in enumerate(value.split(":")) ] - if inspect.isclass(field_type) and issubclass(field_type, enum.Enum): + if isinstance(field_type, type) and issubclass(field_type, enum.Enum): return interpret_enum_env(value, field_type, field_name) msg = f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex." diff --git a/reflex/istate/data.py b/reflex/istate/data.py index 5ae2afcef40..fb07f568d68 100644 --- a/reflex/istate/data.py +++ b/reflex/istate/data.py @@ -83,6 +83,11 @@ def from_router_data(cls, router_data: dict) -> "HeaderData": ) +@serializer(to=dict) +def _serialize_header_data(obj: HeaderData) -> dict: + return dataclasses.asdict(obj) + + @serializer(to=dict) def serialize_frozen_dict_str_str(obj: _FrozenDictStrStr) -> dict: """Serialize a _FrozenDictStrStr object to a dict. @@ -165,6 +170,11 @@ def from_router_data(cls, router_data: dict) -> "PageData": ) +@serializer(to=dict) +def _serialize_page_data(obj: PageData) -> dict: + return dataclasses.asdict(obj) + + @dataclasses.dataclass(frozen=True) class SessionData: """An object containing session data.""" @@ -190,6 +200,11 @@ def from_router_data(cls, router_data: dict) -> "SessionData": ) +@serializer(to=dict) +def _serialize_session_data(obj: SessionData) -> dict: + return dataclasses.asdict(obj) + + @dataclasses.dataclass(frozen=True) class RouterData: """An object containing RouterData.""" diff --git a/reflex/plugins/sitemap.py b/reflex/plugins/sitemap.py index 377f4844591..26e14947b46 100644 --- a/reflex/plugins/sitemap.py +++ b/reflex/plugins/sitemap.py @@ -5,8 +5,7 @@ from pathlib import Path from types import SimpleNamespace from typing import TYPE_CHECKING, Literal, TypedDict -from xml.dom import minidom -from xml.etree.ElementTree import Element, SubElement, tostring +from xml.etree.ElementTree import Element, SubElement, indent, tostring from typing_extensions import NotRequired @@ -104,10 +103,8 @@ def generate_xml(links: Sequence[SitemapLink]) -> str: if (priority := link.get("priority")) is not None: priority_element = SubElement(url, "priority") priority_element.text = str(priority) - - rough_string = tostring(urlset, "utf-8") - reparsed = minidom.parseString(rough_string) - return reparsed.toprettyxml(indent=" ") + indent(urlset, " ") + return tostring(urlset, encoding="utf-8", xml_declaration=True).decode("utf-8") def is_route_dynamic(route: str) -> bool: diff --git a/reflex/reflex.py b/reflex/reflex.py index e2315406a74..cb7f08d4e35 100644 --- a/reflex/reflex.py +++ b/reflex/reflex.py @@ -360,7 +360,13 @@ def run( default=False, help="Run the command without making any changes.", ) -def compile(dry: bool): +@click.option( + "--rich/--no-rich", + default=True, + is_flag=True, + help="Whether to use rich progress bars.", +) +def compile(dry: bool, rich: bool): """Compile the app in the current directory.""" import time @@ -371,7 +377,7 @@ def compile(dry: bool): _init(name=get_config().app_name) get_config(reload=True) starting_time = time.monotonic() - prerequisites.get_compiled_app(dry_run=dry) + prerequisites.get_compiled_app(dry_run=dry, use_rich=rich) elapsed_time = time.monotonic() - starting_time console.success(f"App compiled successfully in {elapsed_time:.3f} seconds.") diff --git a/reflex/state.py b/reflex/state.py index 908ac30a241..b0934d1ff5a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1810,7 +1810,7 @@ async def _process_event( hinted_args = value_inside_optional(hinted_args) if ( isinstance(value, dict) - and inspect.isclass(hinted_args) + and isinstance(hinted_args, type) and not types.is_generic_alias(hinted_args) # py3.10 ): if issubclass(hinted_args, Model): @@ -2341,7 +2341,7 @@ def _serialize_type(type_: Any) -> str: Returns: The serialized type. """ - if not inspect.isclass(type_): + if not isinstance(type_, type): return f"{type_}" return f"{type_.__module__}.{type_.__qualname__}" diff --git a/reflex/testing.py b/reflex/testing.py index 2b229461d35..53f9f82d6ae 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -472,7 +472,7 @@ def get_app_global_source(key: str, value: Any): Returns: The rendered app global code. """ - if not inspect.isclass(value) and not inspect.isfunction(value): + if not isinstance(value, type) and not inspect.isfunction(value): return f"{key} = {value!r}" return inspect.getsource(value) diff --git a/reflex/utils/console.py b/reflex/utils/console.py index d33eef62ed3..6684b198370 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -12,7 +12,7 @@ from types import FrameType from rich.console import Console -from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn +from rich.progress import MofNCompleteColumn, Progress, TaskID, TimeElapsedColumn from rich.prompt import Prompt from reflex.constants import LogLevel @@ -395,3 +395,47 @@ def timing(msg: str): yield finally: debug(f"[white]\\[timing] {msg}: {time.time() - start:.2f}s[/white]") + + +class PoorProgress: + """A poor man's progress bar.""" + + def __init__(self): + """Initialize the progress bar.""" + super().__init__() + self.tasks = {} + self.progress = 0 + self.total = 0 + + def add_task(self, task: str, total: int): + """Add a task to the progress bar. + + Args: + task: The task name. + total: The total number of steps for the task. + + Returns: + The task ID. + """ + self.total = total + task_id = TaskID(len(self.tasks)) + self.tasks[task_id] = {"total": total, "current": 0} + return task_id + + def advance(self, task: TaskID, advance: int = 1): + """Advance the progress of a task. + + Args: + task: The task ID. + advance: The number of steps to advance. + """ + if task in self.tasks: + self.tasks[task]["current"] += advance + self.progress += advance + _console.print(f"Progress: {self.progress}/{self.total}") + + def start(self): + """Start the progress bar.""" + + def stop(self): + """Stop the progress bar.""" diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py index eb7ca99f420..03d7bf79b03 100644 --- a/reflex/utils/prerequisites.py +++ b/reflex/utils/prerequisites.py @@ -248,6 +248,7 @@ def get_compiled_app( prerender_routes: bool = False, dry_run: bool = False, check_if_schema_up_to_date: bool = False, + use_rich: bool = True, ) -> ModuleType: """Get the app module based on the default config after first compiling it. @@ -256,6 +257,7 @@ def get_compiled_app( prerender_routes: Whether to prerender routes. dry_run: If True, do not write the compiled app to disk. check_if_schema_up_to_date: If True, check if the schema is up to date. + use_rich: Whether to use rich progress bars. Returns: The compiled app based on the default config. @@ -263,7 +265,7 @@ def get_compiled_app( app, app_module = get_and_validate_app( reload=reload, check_if_schema_up_to_date=check_if_schema_up_to_date ) - app._compile(prerender_routes=prerender_routes, dry_run=dry_run) + app._compile(prerender_routes=prerender_routes, dry_run=dry_run, use_rich=use_rich) return app_module diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py index 6aa6ac67dee..dd9fd50bb7f 100644 --- a/reflex/utils/pyi_generator.py +++ b/reflex/utils/pyi_generator.py @@ -547,7 +547,7 @@ def _generate_component_create_functiondef( kwargs.extend(prop_kwargs) def figure_out_return_type(annotation: Any): - if inspect.isclass(annotation) and issubclass(annotation, inspect._empty): + if isinstance(annotation, type) and issubclass(annotation, inspect._empty): return ast.Name(id="EventType[Any]") if not isinstance(annotation, str) and get_origin(annotation) is tuple: @@ -1181,7 +1181,7 @@ def _scan_file(self, module_path: Path) -> tuple[str, str] | None: class_names = { name: obj for name, obj in vars(module).items() - if inspect.isclass(obj) + if isinstance(obj, type) and ( rx_types.safe_issubclass(obj, Component) or rx_types.safe_issubclass(obj, SimpleNamespace) diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 5ae7ddd5ae1..fc351f1b8d1 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -187,7 +187,7 @@ def get_serializer(type_: type) -> Serializer | None: # If the type is not registered, check if it is a subclass of a registered type. for registered_type, serializer in reversed(SERIALIZERS.items()): - if types._issubclass(type_, registered_type): + if issubclass(type_, registered_type): return serializer # If there is no serializer, return None. @@ -211,7 +211,7 @@ def get_serializer_type(type_: type) -> type | None: # If the type is not registered, check if it is a subclass of a registered type. for registered_type, serializer in reversed(SERIALIZER_TYPES.items()): - if types._issubclass(type_, registered_type): + if issubclass(type_, registered_type): return serializer # If there is no serializer, return None. @@ -244,11 +244,11 @@ def can_serialize(type_: type, into_type: type | None = None) -> bool: Returns: Whether there is a serializer for the type. """ - return has_serializer(type_, into_type) or ( + return ( isinstance(type_, type) and dataclasses.is_dataclass(type_) and (into_type is None or into_type is dict) - ) + ) or has_serializer(type_, into_type) @serializer(to=str) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 6b83486cf7f..0a92d9c209c 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -52,11 +52,8 @@ GenericType = type | _GenericAlias # Valid state var types. -JSONType = {str, int, float, bool} -PrimitiveType = int | float | bool | str | list | dict | set | tuple PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple) -StateVar = PrimitiveType | Base | None -StateIterVar = list | set | tuple +StateVarTypes = (*PrimitiveTypes, Base, type(None)) if TYPE_CHECKING: from reflex.vars.base import Var @@ -401,6 +398,8 @@ def get_field_type(cls: GenericType, field_name: str) -> GenericType | None: Returns: The type of the field, if it exists, else None. """ + if (fields := getattr(cls, "_fields", None)) is not None and field_name in fields: + return fields[field_name].annotated_type if ( hasattr(cls, "__fields__") and field_name in cls.__fields__ @@ -858,7 +857,7 @@ def is_valid_var_type(type_: type) -> bool: if is_union(type_): return all(is_valid_var_type(arg) for arg in get_args(type_)) return ( - _issubclass(type_, StateVar) + issubclass(type_, StateVarTypes) or serializers.has_serializer(type_) or dataclasses.is_dataclass(type_) ) @@ -993,11 +992,6 @@ def validate_literal(key: str, value: Any, expected_type: type, comp_name: str): raise ValueError(msg) -# Store this here for performance. -StateBases = get_base_class(StateVar) -StateIterBases = get_base_class(StateIterVar) - - def safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]): """Check if a class is a subclass of another class. Returns False if internal error occurs. diff --git a/reflex/vars/base.py b/reflex/vars/base.py index 020d09aea05..554ddce899d 100644 --- a/reflex/vars/base.py +++ b/reflex/vars/base.py @@ -366,7 +366,7 @@ def can_use_in_object_var(cls: GenericType) -> bool: if types.is_union(cls): return all(can_use_in_object_var(t) for t in types.get_args(cls)) return ( - inspect.isclass(cls) + isinstance(cls, type) and not safe_issubclass(cls, Var) and serializers.can_serialize(cls, dict) ) @@ -516,6 +516,17 @@ def _get_all_var_data(self) -> VarData | None: """ return self._var_data + def __deepcopy__(self, memo: dict[int, Any]) -> Self: + """Deepcopy the var. + + Args: + memo: The memo dictionary to use for the deepcopy. + + Returns: + A deepcopy of the var. + """ + return self + def equals(self, other: Var) -> bool: """Check if two vars are equal. @@ -795,7 +806,7 @@ def to( if can_use_in_object_var(output): return self.to(ObjectVar, output) - if inspect.isclass(output): + if isinstance(output, type): for var_subclass in _var_subclasses[::-1]: if safe_issubclass(output, var_subclass.var_subclass): current_var_type = self._var_type @@ -891,7 +902,7 @@ def guess_type(self) -> Var: args = get_args(var_type) fixed_type = unionize(*(type(arg) for arg in args)) - if not inspect.isclass(fixed_type): + if not isinstance(fixed_type, type): msg = f"Unsupported type {var_type} for guess_type." raise TypeError(msg) @@ -1153,18 +1164,6 @@ def _without_data(self): """ return dataclasses.replace(self, _var_data=None) - def __get__(self, instance: Any, owner: Any): - """Get the var. - - Args: - instance: The instance to get the var from. - owner: The owner of the var. - - Returns: - The var. - """ - return self - def _decode(self) -> Any: """Decode Var as a python value. @@ -1409,7 +1408,7 @@ def __init_subclass__(cls, **kwargs): bases = cls.__bases__ bases_normalized = [ - base if inspect.isclass(base) else get_origin(base) for base in bases + base if isinstance(base, type) else get_origin(base) for base in bases ] possible_bases = [ diff --git a/reflex/vars/object.py b/reflex/vars/object.py index 3d514926624..53f682e8ce6 100644 --- a/reflex/vars/object.py +++ b/reflex/vars/object.py @@ -6,7 +6,6 @@ import dataclasses import typing from collections.abc import Mapping -from inspect import isclass from typing import ( Any, NoReturn, @@ -328,7 +327,10 @@ def __getattr__(self, name: str) -> Var: if ( is_typeddict(fixed_type) - or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping)) + or ( + isinstance(fixed_type, type) + and not safe_issubclass(fixed_type, Mapping) + ) or (fixed_type in types.UnionTypes) ): attribute_type = get_attribute_access_type(var_type, name) diff --git a/tests/units/plugins/test_sitemap.py b/tests/units/plugins/test_sitemap.py index f79a4144da7..ba1af78807e 100644 --- a/tests/units/plugins/test_sitemap.py +++ b/tests/units/plugins/test_sitemap.py @@ -11,7 +11,8 @@ def test_generate_xml_empty_links(): """Test generate_xml with an empty list of links.""" xml_output = generate_xml([]) - expected = '\n\n' + expected = """ +""" assert xml_output == expected @@ -19,13 +20,12 @@ def test_generate_xml_single_link_loc_only(): """Test generate_xml with a single link having only loc.""" links: list[SitemapLink] = [{"loc": "https://example.com"}] xml_output = generate_xml(links) - expected = """ + expected = """ https://example.com - -""" +""" assert xml_output == expected @@ -47,7 +47,7 @@ def test_generate_xml_multiple_links_all_fields(): }, ] xml_output = generate_xml(links) - expected = """ + expected = """ https://example.com/page1 @@ -61,8 +61,7 @@ def test_generate_xml_multiple_links_all_fields(): 2023-01-01T00:00:00 0.5 - -""" +""" assert xml_output == expected From 745ca1b49570919fed2a856cfd97d64b15993e04 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 28 Aug 2025 13:51:47 -0700 Subject: [PATCH 2/3] fix few stuff --- reflex/components/component.py | 2 +- reflex/istate/data.py | 2 +- reflex/utils/types.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/reflex/components/component.py b/reflex/components/component.py index a448e0aa398..2592b7f662a 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -1516,7 +1516,7 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen Returns: A sequence of parent classes that define the method (differently than the base). """ - current_class_method = getattr(cls, method, None) + current_class_method = getattr(Component, method, None) seen_methods = ( {current_class_method} if current_class_method is not None else set() ) diff --git a/reflex/istate/data.py b/reflex/istate/data.py index fb07f568d68..3d09a35e97a 100644 --- a/reflex/istate/data.py +++ b/reflex/istate/data.py @@ -85,7 +85,7 @@ def from_router_data(cls, router_data: dict) -> "HeaderData": @serializer(to=dict) def _serialize_header_data(obj: HeaderData) -> dict: - return dataclasses.asdict(obj) + return {k.name: getattr(obj, k.name) for k in dataclasses.fields(obj)} @serializer(to=dict) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 0a92d9c209c..93e253b22c5 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -856,6 +856,9 @@ def is_valid_var_type(type_: type) -> bool: if is_union(type_): return all(is_valid_var_type(arg) for arg in get_args(type_)) + + type_ = origin if (origin := get_origin(type_)) is not None else type_ + return ( issubclass(type_, StateVarTypes) or serializers.has_serializer(type_) From 7746052bcc55b163edf25cc48fb55616258024e1 Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 28 Aug 2025 13:58:44 -0700 Subject: [PATCH 3/3] add total --- reflex/utils/console.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/utils/console.py b/reflex/utils/console.py index 6684b198370..d498b09d323 100644 --- a/reflex/utils/console.py +++ b/reflex/utils/console.py @@ -417,7 +417,7 @@ def add_task(self, task: str, total: int): Returns: The task ID. """ - self.total = total + self.total += total task_id = TaskID(len(self.tasks)) self.tasks[task_id] = {"total": total, "current": 0} return task_id