Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions reflex/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,12 +1100,18 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No
for substate in state.class_subclasses:
self._validate_var_dependencies(substate)

def _compile(self, prerender_routes: bool = False, dry_run: bool = False):
def _compile(
self,
prerender_routes: bool = False,
dry_run: bool = False,
use_rich: bool = True,
):
"""Compile the app and output it to the pages folder.

Args:
prerender_routes: Whether to prerender the routes.
dry_run: Whether to compile the app without saving it.
use_rich: Whether to use rich progress bars.

Raises:
ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined.
Expand Down Expand Up @@ -1171,10 +1177,14 @@ def get_compilation_time() -> str:
return

# Create a progress bar.
progress = Progress(
*Progress.get_default_columns()[:-1],
MofNCompleteColumn(),
TimeElapsedColumn(),
progress = (
Progress(
*Progress.get_default_columns()[:-1],
MofNCompleteColumn(),
TimeElapsedColumn(),
)
if use_rich
else console.PoorProgress()
)

# try to be somewhat accurate - but still not 100%
Expand Down
2 changes: 1 addition & 1 deletion reflex/compiler/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def context_template(
export const DispatchContext = createContext(null);
export const StateContexts = {{{state_contexts_str}}};
export const EventLoopContext = createContext(null);
export const clientStorage = {"{}" if client_storage is None else json_dumps(client_storage)}
export const clientStorage = {"{}" if client_storage is None else json.dumps(client_storage)}

{state_str}

Expand Down
11 changes: 7 additions & 4 deletions reflex/compiler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,13 @@ def compile_state(state: type[BaseState]) -> dict:

def _compile_client_storage_field(
field: Field,
) -> tuple[
type[Cookie] | type[LocalStorage] | type[SessionStorage] | None,
dict[str, Any] | None,
]:
) -> (
tuple[
type[Cookie] | type[LocalStorage] | type[SessionStorage],
dict[str, Any],
]
| tuple[None, None]
):
"""Compile the given cookie, local_storage or session_storage field.

Args:
Expand Down
75 changes: 42 additions & 33 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def evaluate_style_namespaces(style: ComponentStyle) -> dict:

# Map from component to styling.
ComponentStyle = dict[str | type[BaseComponent] | Callable | ComponentNamespace, Any]
ComponentChild = types.PrimitiveType | Var | BaseComponent
ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent, type(None))


Expand Down Expand Up @@ -480,7 +479,21 @@ def _components_from(
return ()


def _deterministic_hash(value: object) -> int:
def _hash_str(value: str) -> str:
return md5(f'"{value}"'.encode(), usedforsecurity=False).hexdigest()


def _hash_sequence(value: Sequence) -> str:
return _hash_str(str([_deterministic_hash(v) for v in value]))


def _hash_dict(value: dict) -> str:
return _hash_sequence(
sorted([(k, _deterministic_hash(v)) for k, v in value.items()])
)


def _deterministic_hash(value: object) -> str:
"""Hash a rendered dictionary.

Args:
Expand All @@ -492,37 +505,28 @@ def _deterministic_hash(value: object) -> int:
Raises:
TypeError: If the value is not hashable.
"""
if isinstance(value, BaseComponent):
# If the value is a component, hash its rendered code.
rendered_code = value.render()
return _deterministic_hash(rendered_code)
if isinstance(value, Var):
return _deterministic_hash((value._js_expr, value._get_all_var_data()))
if isinstance(value, VarData):
return _deterministic_hash(dataclasses.asdict(value))
if isinstance(value, dict):
# Sort the dictionary to ensure consistent hashing.
return _deterministic_hash(
tuple(sorted((k, _deterministic_hash(v)) for k, v in value.items()))
)
if isinstance(value, int):
if value is None:
# Hash None as a special case.
return "None"
if isinstance(value, (int, float, enum.Enum)):
# Hash numbers and booleans directly.
return int(value)
if isinstance(value, float):
return _deterministic_hash(str(value))
return str(value)
if isinstance(value, str):
return int(md5(f'"{value}"'.encode()).hexdigest(), 16)
return _hash_str(value)
if isinstance(value, dict):
return _hash_dict(value)
if isinstance(value, (tuple, list)):
# Hash tuples by hashing each element.
return _deterministic_hash(
"[" + ",".join(map(str, map(_deterministic_hash, value))) + "]"
return _hash_sequence(value)
if isinstance(value, Var):
return _hash_str(
str((value._js_expr, _deterministic_hash(value._get_all_var_data())))
)
if isinstance(value, enum.Enum):
# Hash enums by their name.
return _deterministic_hash(str(value))
if value is None:
# Hash None as a special case.
return _deterministic_hash("None")
if isinstance(value, VarData):
return _hash_dict(dataclasses.asdict(value))
if isinstance(value, BaseComponent):
# If the value is a component, hash its rendered code.
return _hash_dict(value.render())

msg = (
f"Cannot hash value `{value}` of type `{type(value).__name__}`. "
Expand Down Expand Up @@ -1038,9 +1042,13 @@ def _get_component_prop_names(cls) -> set[str]:
name
for name in cls.get_fields()
if name in cls.get_props()
and types._issubclass(
types.value_inside_optional(types.get_field_type(cls, name)), Component
and isinstance(
field_type := types.value_inside_optional(
types.get_field_type(cls, name)
),
type,
)
and issubclass(field_type, Component)
}

def _get_components_in_props(self) -> Sequence[BaseComponent]:
Expand Down Expand Up @@ -1508,8 +1516,9 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen
Returns:
A sequence of parent classes that define the method (differently than the base).
"""
current_class_method = getattr(Component, method, None)
seen_methods = (
{getattr(Component, method)} if hasattr(Component, method) else set()
{current_class_method} if current_class_method is not None else set()
)
clzs: list[type[Component]] = []
for clz in cls.mro():
Expand Down Expand Up @@ -1807,7 +1816,7 @@ def extract_var_hooks(hook: Var):

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
for clz in reversed(self._iter_parent_classes_with_method("add_hooks")):
for hook in clz.add_hooks(self):
if isinstance(hook, Var):
extract_var_hooks(hook)
Expand Down Expand Up @@ -2457,7 +2466,7 @@ def _get_tag_name(cls, component: Component) -> str | None:
return None

# Compute the hash based on the rendered code.
code_hash = _deterministic_hash(rendered_code)
code_hash = _hash_str(_deterministic_hash(rendered_code))
Comment thread
adhami3310 marked this conversation as resolved.

# Format the tag name including the hash.
return format.format_state_name(
Expand Down
5 changes: 2 additions & 3 deletions reflex/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import dataclasses
import enum
import importlib
import inspect
import multiprocessing
import os
import platform
Expand Down Expand Up @@ -159,7 +158,7 @@ def interpret_plugin_env(value: str, field_name: str) -> Plugin:
msg = f"Failed to get plugin class {plugin_name!r} from module {import_path!r} for {field_name}: {e}"
raise EnvironmentVarValueError(msg) from e

if not inspect.isclass(plugin_class) or not issubclass(plugin_class, Plugin):
if not isinstance(plugin_class, type) or not issubclass(plugin_class, Plugin):
msg = f"Invalid plugin class: {plugin_name!r} for {field_name}. Must be a subclass of Plugin."
raise EnvironmentVarValueError(msg)

Expand Down Expand Up @@ -236,7 +235,7 @@ def interpret_env_var_value(
)
for i, v in enumerate(value.split(":"))
]
if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
return interpret_enum_env(value, field_type, field_name)

msg = f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
Expand Down
15 changes: 15 additions & 0 deletions reflex/istate/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ def from_router_data(cls, router_data: dict) -> "HeaderData":
)


@serializer(to=dict)
def _serialize_header_data(obj: HeaderData) -> dict:
return {k.name: getattr(obj, k.name) for k in dataclasses.fields(obj)}


@serializer(to=dict)
def serialize_frozen_dict_str_str(obj: _FrozenDictStrStr) -> dict:
"""Serialize a _FrozenDictStrStr object to a dict.
Expand Down Expand Up @@ -165,6 +170,11 @@ def from_router_data(cls, router_data: dict) -> "PageData":
)


@serializer(to=dict)
def _serialize_page_data(obj: PageData) -> dict:
return dataclasses.asdict(obj)


@dataclasses.dataclass(frozen=True)
class SessionData:
"""An object containing session data."""
Expand All @@ -190,6 +200,11 @@ def from_router_data(cls, router_data: dict) -> "SessionData":
)


@serializer(to=dict)
def _serialize_session_data(obj: SessionData) -> dict:
return dataclasses.asdict(obj)


@dataclasses.dataclass(frozen=True)
class RouterData:
"""An object containing RouterData."""
Expand Down
9 changes: 3 additions & 6 deletions reflex/plugins/sitemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Literal, TypedDict
from xml.dom import minidom
from xml.etree.ElementTree import Element, SubElement, tostring
from xml.etree.ElementTree import Element, SubElement, indent, tostring

from typing_extensions import NotRequired

Expand Down Expand Up @@ -104,10 +103,8 @@ def generate_xml(links: Sequence[SitemapLink]) -> str:
if (priority := link.get("priority")) is not None:
priority_element = SubElement(url, "priority")
priority_element.text = str(priority)

rough_string = tostring(urlset, "utf-8")
reparsed = minidom.parseString(rough_string)
return reparsed.toprettyxml(indent=" ")
indent(urlset, " ")
return tostring(urlset, encoding="utf-8", xml_declaration=True).decode("utf-8")


def is_route_dynamic(route: str) -> bool:
Expand Down
10 changes: 8 additions & 2 deletions reflex/reflex.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,13 @@ def run(
default=False,
help="Run the command without making any changes.",
)
def compile(dry: bool):
@click.option(
"--rich/--no-rich",
default=True,
is_flag=True,
help="Whether to use rich progress bars.",
)
def compile(dry: bool, rich: bool):
"""Compile the app in the current directory."""
import time

Expand All @@ -371,7 +377,7 @@ def compile(dry: bool):
_init(name=get_config().app_name)
get_config(reload=True)
starting_time = time.monotonic()
prerequisites.get_compiled_app(dry_run=dry)
prerequisites.get_compiled_app(dry_run=dry, use_rich=rich)
elapsed_time = time.monotonic() - starting_time
console.success(f"App compiled successfully in {elapsed_time:.3f} seconds.")

Expand Down
4 changes: 2 additions & 2 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,7 +1810,7 @@ async def _process_event(
hinted_args = value_inside_optional(hinted_args)
if (
isinstance(value, dict)
and inspect.isclass(hinted_args)
and isinstance(hinted_args, type)
and not types.is_generic_alias(hinted_args) # py3.10
):
if issubclass(hinted_args, Model):
Expand Down Expand Up @@ -2341,7 +2341,7 @@ def _serialize_type(type_: Any) -> str:
Returns:
The serialized type.
"""
if not inspect.isclass(type_):
if not isinstance(type_, type):
return f"{type_}"
return f"{type_.__module__}.{type_.__qualname__}"

Expand Down
2 changes: 1 addition & 1 deletion reflex/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def get_app_global_source(key: str, value: Any):
Returns:
The rendered app global code.
"""
if not inspect.isclass(value) and not inspect.isfunction(value):
if not isinstance(value, type) and not inspect.isfunction(value):
return f"{key} = {value!r}"
return inspect.getsource(value)

Expand Down
46 changes: 45 additions & 1 deletion reflex/utils/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from types import FrameType

from rich.console import Console
from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
from rich.progress import MofNCompleteColumn, Progress, TaskID, TimeElapsedColumn
from rich.prompt import Prompt

from reflex.constants import LogLevel
Expand Down Expand Up @@ -395,3 +395,47 @@ def timing(msg: str):
yield
finally:
debug(f"[white]\\[timing] {msg}: {time.time() - start:.2f}s[/white]")


class PoorProgress:
"""A poor man's progress bar."""

def __init__(self):
"""Initialize the progress bar."""
super().__init__()
self.tasks = {}
self.progress = 0
self.total = 0

def add_task(self, task: str, total: int):
"""Add a task to the progress bar.

Args:
task: The task name.
total: The total number of steps for the task.

Returns:
The task ID.
"""
self.total += total
task_id = TaskID(len(self.tasks))
self.tasks[task_id] = {"total": total, "current": 0}
return task_id

def advance(self, task: TaskID, advance: int = 1):
"""Advance the progress of a task.

Args:
task: The task ID.
advance: The number of steps to advance.
"""
if task in self.tasks:
self.tasks[task]["current"] += advance
self.progress += advance
_console.print(f"Progress: {self.progress}/{self.total}")

def start(self):
"""Start the progress bar."""

def stop(self):
"""Stop the progress bar."""
Loading
Loading