diff --git a/reflex/app.py b/reflex/app.py
index dae1a80fd11..a298d1beef0 100644
--- a/reflex/app.py
+++ b/reflex/app.py
@@ -1100,12 +1100,18 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No
for substate in state.class_subclasses:
self._validate_var_dependencies(substate)
- def _compile(self, prerender_routes: bool = False, dry_run: bool = False):
+ def _compile(
+ self,
+ prerender_routes: bool = False,
+ dry_run: bool = False,
+ use_rich: bool = True,
+ ):
"""Compile the app and output it to the pages folder.
Args:
prerender_routes: Whether to prerender the routes.
dry_run: Whether to compile the app without saving it.
+ use_rich: Whether to use rich progress bars.
Raises:
ReflexRuntimeError: When any page uses state, but no rx.State subclass is defined.
@@ -1171,10 +1177,14 @@ def get_compilation_time() -> str:
return
# Create a progress bar.
- progress = Progress(
- *Progress.get_default_columns()[:-1],
- MofNCompleteColumn(),
- TimeElapsedColumn(),
+ progress = (
+ Progress(
+ *Progress.get_default_columns()[:-1],
+ MofNCompleteColumn(),
+ TimeElapsedColumn(),
+ )
+ if use_rich
+ else console.PoorProgress()
)
# try to be somewhat accurate - but still not 100%
diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py
index ba286a8decc..9712f026b25 100644
--- a/reflex/compiler/templates.py
+++ b/reflex/compiler/templates.py
@@ -362,7 +362,7 @@ def context_template(
export const DispatchContext = createContext(null);
export const StateContexts = {{{state_contexts_str}}};
export const EventLoopContext = createContext(null);
-export const clientStorage = {"{}" if client_storage is None else json_dumps(client_storage)}
+export const clientStorage = {"{}" if client_storage is None else json.dumps(client_storage)}
{state_str}
diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py
index f85a5fe84c0..3a8a33a5a9c 100644
--- a/reflex/compiler/utils.py
+++ b/reflex/compiler/utils.py
@@ -220,10 +220,13 @@ def compile_state(state: type[BaseState]) -> dict:
def _compile_client_storage_field(
field: Field,
-) -> tuple[
- type[Cookie] | type[LocalStorage] | type[SessionStorage] | None,
- dict[str, Any] | None,
-]:
+) -> (
+ tuple[
+ type[Cookie] | type[LocalStorage] | type[SessionStorage],
+ dict[str, Any],
+ ]
+ | tuple[None, None]
+):
"""Compile the given cookie, local_storage or session_storage field.
Args:
diff --git a/reflex/components/component.py b/reflex/components/component.py
index 5db03404668..2592b7f662a 100644
--- a/reflex/components/component.py
+++ b/reflex/components/component.py
@@ -419,7 +419,6 @@ def evaluate_style_namespaces(style: ComponentStyle) -> dict:
# Map from component to styling.
ComponentStyle = dict[str | type[BaseComponent] | Callable | ComponentNamespace, Any]
-ComponentChild = types.PrimitiveType | Var | BaseComponent
ComponentChildTypes = (*types.PrimitiveTypes, Var, BaseComponent, type(None))
@@ -480,7 +479,21 @@ def _components_from(
return ()
-def _deterministic_hash(value: object) -> int:
+def _hash_str(value: str) -> str:
+ return md5(f'"{value}"'.encode(), usedforsecurity=False).hexdigest()
+
+
+def _hash_sequence(value: Sequence) -> str:
+ return _hash_str(str([_deterministic_hash(v) for v in value]))
+
+
+def _hash_dict(value: dict) -> str:
+ return _hash_sequence(
+ sorted([(k, _deterministic_hash(v)) for k, v in value.items()])
+ )
+
+
+def _deterministic_hash(value: object) -> str:
"""Hash a rendered dictionary.
Args:
@@ -492,37 +505,28 @@ def _deterministic_hash(value: object) -> int:
Raises:
TypeError: If the value is not hashable.
"""
- if isinstance(value, BaseComponent):
- # If the value is a component, hash its rendered code.
- rendered_code = value.render()
- return _deterministic_hash(rendered_code)
- if isinstance(value, Var):
- return _deterministic_hash((value._js_expr, value._get_all_var_data()))
- if isinstance(value, VarData):
- return _deterministic_hash(dataclasses.asdict(value))
- if isinstance(value, dict):
- # Sort the dictionary to ensure consistent hashing.
- return _deterministic_hash(
- tuple(sorted((k, _deterministic_hash(v)) for k, v in value.items()))
- )
- if isinstance(value, int):
+ if value is None:
+ # Hash None as a special case.
+ return "None"
+ if isinstance(value, (int, float, enum.Enum)):
# Hash numbers and booleans directly.
- return int(value)
- if isinstance(value, float):
- return _deterministic_hash(str(value))
+ return str(value)
if isinstance(value, str):
- return int(md5(f'"{value}"'.encode()).hexdigest(), 16)
+ return _hash_str(value)
+ if isinstance(value, dict):
+ return _hash_dict(value)
if isinstance(value, (tuple, list)):
# Hash tuples by hashing each element.
- return _deterministic_hash(
- "[" + ",".join(map(str, map(_deterministic_hash, value))) + "]"
+ return _hash_sequence(value)
+ if isinstance(value, Var):
+ return _hash_str(
+ str((value._js_expr, _deterministic_hash(value._get_all_var_data())))
)
- if isinstance(value, enum.Enum):
- # Hash enums by their name.
- return _deterministic_hash(str(value))
- if value is None:
- # Hash None as a special case.
- return _deterministic_hash("None")
+ if isinstance(value, VarData):
+ return _hash_dict(dataclasses.asdict(value))
+ if isinstance(value, BaseComponent):
+ # If the value is a component, hash its rendered code.
+ return _hash_dict(value.render())
msg = (
f"Cannot hash value `{value}` of type `{type(value).__name__}`. "
@@ -1038,9 +1042,13 @@ def _get_component_prop_names(cls) -> set[str]:
name
for name in cls.get_fields()
if name in cls.get_props()
- and types._issubclass(
- types.value_inside_optional(types.get_field_type(cls, name)), Component
+ and isinstance(
+ field_type := types.value_inside_optional(
+ types.get_field_type(cls, name)
+ ),
+ type,
)
+ and issubclass(field_type, Component)
}
def _get_components_in_props(self) -> Sequence[BaseComponent]:
@@ -1508,8 +1516,9 @@ def _iter_parent_classes_with_method(cls, method: str) -> Sequence[type[Componen
Returns:
A sequence of parent classes that define the method (differently than the base).
"""
+ current_class_method = getattr(Component, method, None)
seen_methods = (
- {getattr(Component, method)} if hasattr(Component, method) else set()
+ {current_class_method} if current_class_method is not None else set()
)
clzs: list[type[Component]] = []
for clz in cls.mro():
@@ -1807,7 +1816,7 @@ def extract_var_hooks(hook: Var):
# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
- for clz in reversed(tuple(self._iter_parent_classes_with_method("add_hooks"))):
+ for clz in reversed(self._iter_parent_classes_with_method("add_hooks")):
for hook in clz.add_hooks(self):
if isinstance(hook, Var):
extract_var_hooks(hook)
@@ -2457,7 +2466,7 @@ def _get_tag_name(cls, component: Component) -> str | None:
return None
# Compute the hash based on the rendered code.
- code_hash = _deterministic_hash(rendered_code)
+ code_hash = _hash_str(_deterministic_hash(rendered_code))
# Format the tag name including the hash.
return format.format_state_name(
diff --git a/reflex/environment.py b/reflex/environment.py
index 5e5eb5ff873..890fe5743ed 100644
--- a/reflex/environment.py
+++ b/reflex/environment.py
@@ -6,7 +6,6 @@
import dataclasses
import enum
import importlib
-import inspect
import multiprocessing
import os
import platform
@@ -159,7 +158,7 @@ def interpret_plugin_env(value: str, field_name: str) -> Plugin:
msg = f"Failed to get plugin class {plugin_name!r} from module {import_path!r} for {field_name}: {e}"
raise EnvironmentVarValueError(msg) from e
- if not inspect.isclass(plugin_class) or not issubclass(plugin_class, Plugin):
+ if not isinstance(plugin_class, type) or not issubclass(plugin_class, Plugin):
msg = f"Invalid plugin class: {plugin_name!r} for {field_name}. Must be a subclass of Plugin."
raise EnvironmentVarValueError(msg)
@@ -236,7 +235,7 @@ def interpret_env_var_value(
)
for i, v in enumerate(value.split(":"))
]
- if inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
+ if isinstance(field_type, type) and issubclass(field_type, enum.Enum):
return interpret_enum_env(value, field_type, field_name)
msg = f"Invalid type for environment variable {field_name}: {field_type}. This is probably an issue in Reflex."
diff --git a/reflex/istate/data.py b/reflex/istate/data.py
index 5ae2afcef40..3d09a35e97a 100644
--- a/reflex/istate/data.py
+++ b/reflex/istate/data.py
@@ -83,6 +83,11 @@ def from_router_data(cls, router_data: dict) -> "HeaderData":
)
+@serializer(to=dict)
+def _serialize_header_data(obj: HeaderData) -> dict:
+ return {k.name: getattr(obj, k.name) for k in dataclasses.fields(obj)}
+
+
@serializer(to=dict)
def serialize_frozen_dict_str_str(obj: _FrozenDictStrStr) -> dict:
"""Serialize a _FrozenDictStrStr object to a dict.
@@ -165,6 +170,11 @@ def from_router_data(cls, router_data: dict) -> "PageData":
)
+@serializer(to=dict)
+def _serialize_page_data(obj: PageData) -> dict:
+ return dataclasses.asdict(obj)
+
+
@dataclasses.dataclass(frozen=True)
class SessionData:
"""An object containing session data."""
@@ -190,6 +200,11 @@ def from_router_data(cls, router_data: dict) -> "SessionData":
)
+@serializer(to=dict)
+def _serialize_session_data(obj: SessionData) -> dict:
+ return dataclasses.asdict(obj)
+
+
@dataclasses.dataclass(frozen=True)
class RouterData:
"""An object containing RouterData."""
diff --git a/reflex/plugins/sitemap.py b/reflex/plugins/sitemap.py
index 377f4844591..26e14947b46 100644
--- a/reflex/plugins/sitemap.py
+++ b/reflex/plugins/sitemap.py
@@ -5,8 +5,7 @@
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Literal, TypedDict
-from xml.dom import minidom
-from xml.etree.ElementTree import Element, SubElement, tostring
+from xml.etree.ElementTree import Element, SubElement, indent, tostring
from typing_extensions import NotRequired
@@ -104,10 +103,8 @@ def generate_xml(links: Sequence[SitemapLink]) -> str:
if (priority := link.get("priority")) is not None:
priority_element = SubElement(url, "priority")
priority_element.text = str(priority)
-
- rough_string = tostring(urlset, "utf-8")
- reparsed = minidom.parseString(rough_string)
- return reparsed.toprettyxml(indent=" ")
+ indent(urlset, " ")
+ return tostring(urlset, encoding="utf-8", xml_declaration=True).decode("utf-8")
def is_route_dynamic(route: str) -> bool:
diff --git a/reflex/reflex.py b/reflex/reflex.py
index e2315406a74..cb7f08d4e35 100644
--- a/reflex/reflex.py
+++ b/reflex/reflex.py
@@ -360,7 +360,13 @@ def run(
default=False,
help="Run the command without making any changes.",
)
-def compile(dry: bool):
+@click.option(
+ "--rich/--no-rich",
+ default=True,
+ is_flag=True,
+ help="Whether to use rich progress bars.",
+)
+def compile(dry: bool, rich: bool):
"""Compile the app in the current directory."""
import time
@@ -371,7 +377,7 @@ def compile(dry: bool):
_init(name=get_config().app_name)
get_config(reload=True)
starting_time = time.monotonic()
- prerequisites.get_compiled_app(dry_run=dry)
+ prerequisites.get_compiled_app(dry_run=dry, use_rich=rich)
elapsed_time = time.monotonic() - starting_time
console.success(f"App compiled successfully in {elapsed_time:.3f} seconds.")
diff --git a/reflex/state.py b/reflex/state.py
index 908ac30a241..b0934d1ff5a 100644
--- a/reflex/state.py
+++ b/reflex/state.py
@@ -1810,7 +1810,7 @@ async def _process_event(
hinted_args = value_inside_optional(hinted_args)
if (
isinstance(value, dict)
- and inspect.isclass(hinted_args)
+ and isinstance(hinted_args, type)
and not types.is_generic_alias(hinted_args) # py3.10
):
if issubclass(hinted_args, Model):
@@ -2341,7 +2341,7 @@ def _serialize_type(type_: Any) -> str:
Returns:
The serialized type.
"""
- if not inspect.isclass(type_):
+ if not isinstance(type_, type):
return f"{type_}"
return f"{type_.__module__}.{type_.__qualname__}"
diff --git a/reflex/testing.py b/reflex/testing.py
index 2b229461d35..53f9f82d6ae 100644
--- a/reflex/testing.py
+++ b/reflex/testing.py
@@ -472,7 +472,7 @@ def get_app_global_source(key: str, value: Any):
Returns:
The rendered app global code.
"""
- if not inspect.isclass(value) and not inspect.isfunction(value):
+ if not isinstance(value, type) and not inspect.isfunction(value):
return f"{key} = {value!r}"
return inspect.getsource(value)
diff --git a/reflex/utils/console.py b/reflex/utils/console.py
index d33eef62ed3..d498b09d323 100644
--- a/reflex/utils/console.py
+++ b/reflex/utils/console.py
@@ -12,7 +12,7 @@
from types import FrameType
from rich.console import Console
-from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn
+from rich.progress import MofNCompleteColumn, Progress, TaskID, TimeElapsedColumn
from rich.prompt import Prompt
from reflex.constants import LogLevel
@@ -395,3 +395,47 @@ def timing(msg: str):
yield
finally:
debug(f"[white]\\[timing] {msg}: {time.time() - start:.2f}s[/white]")
+
+
+class PoorProgress:
+ """A poor man's progress bar."""
+
+ def __init__(self):
+ """Initialize the progress bar."""
+ super().__init__()
+ self.tasks = {}
+ self.progress = 0
+ self.total = 0
+
+ def add_task(self, task: str, total: int):
+ """Add a task to the progress bar.
+
+ Args:
+ task: The task name.
+ total: The total number of steps for the task.
+
+ Returns:
+ The task ID.
+ """
+ self.total += total
+ task_id = TaskID(len(self.tasks))
+ self.tasks[task_id] = {"total": total, "current": 0}
+ return task_id
+
+ def advance(self, task: TaskID, advance: int = 1):
+ """Advance the progress of a task.
+
+ Args:
+ task: The task ID.
+ advance: The number of steps to advance.
+ """
+ if task in self.tasks:
+ self.tasks[task]["current"] += advance
+ self.progress += advance
+ _console.print(f"Progress: {self.progress}/{self.total}")
+
+ def start(self):
+ """Start the progress bar."""
+
+ def stop(self):
+ """Stop the progress bar."""
diff --git a/reflex/utils/prerequisites.py b/reflex/utils/prerequisites.py
index eb7ca99f420..03d7bf79b03 100644
--- a/reflex/utils/prerequisites.py
+++ b/reflex/utils/prerequisites.py
@@ -248,6 +248,7 @@ def get_compiled_app(
prerender_routes: bool = False,
dry_run: bool = False,
check_if_schema_up_to_date: bool = False,
+ use_rich: bool = True,
) -> ModuleType:
"""Get the app module based on the default config after first compiling it.
@@ -256,6 +257,7 @@ def get_compiled_app(
prerender_routes: Whether to prerender routes.
dry_run: If True, do not write the compiled app to disk.
check_if_schema_up_to_date: If True, check if the schema is up to date.
+ use_rich: Whether to use rich progress bars.
Returns:
The compiled app based on the default config.
@@ -263,7 +265,7 @@ def get_compiled_app(
app, app_module = get_and_validate_app(
reload=reload, check_if_schema_up_to_date=check_if_schema_up_to_date
)
- app._compile(prerender_routes=prerender_routes, dry_run=dry_run)
+ app._compile(prerender_routes=prerender_routes, dry_run=dry_run, use_rich=use_rich)
return app_module
diff --git a/reflex/utils/pyi_generator.py b/reflex/utils/pyi_generator.py
index 6aa6ac67dee..dd9fd50bb7f 100644
--- a/reflex/utils/pyi_generator.py
+++ b/reflex/utils/pyi_generator.py
@@ -547,7 +547,7 @@ def _generate_component_create_functiondef(
kwargs.extend(prop_kwargs)
def figure_out_return_type(annotation: Any):
- if inspect.isclass(annotation) and issubclass(annotation, inspect._empty):
+ if isinstance(annotation, type) and issubclass(annotation, inspect._empty):
return ast.Name(id="EventType[Any]")
if not isinstance(annotation, str) and get_origin(annotation) is tuple:
@@ -1181,7 +1181,7 @@ def _scan_file(self, module_path: Path) -> tuple[str, str] | None:
class_names = {
name: obj
for name, obj in vars(module).items()
- if inspect.isclass(obj)
+ if isinstance(obj, type)
and (
rx_types.safe_issubclass(obj, Component)
or rx_types.safe_issubclass(obj, SimpleNamespace)
diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py
index 5ae7ddd5ae1..fc351f1b8d1 100644
--- a/reflex/utils/serializers.py
+++ b/reflex/utils/serializers.py
@@ -187,7 +187,7 @@ def get_serializer(type_: type) -> Serializer | None:
# If the type is not registered, check if it is a subclass of a registered type.
for registered_type, serializer in reversed(SERIALIZERS.items()):
- if types._issubclass(type_, registered_type):
+ if issubclass(type_, registered_type):
return serializer
# If there is no serializer, return None.
@@ -211,7 +211,7 @@ def get_serializer_type(type_: type) -> type | None:
# If the type is not registered, check if it is a subclass of a registered type.
for registered_type, serializer in reversed(SERIALIZER_TYPES.items()):
- if types._issubclass(type_, registered_type):
+ if issubclass(type_, registered_type):
return serializer
# If there is no serializer, return None.
@@ -244,11 +244,11 @@ def can_serialize(type_: type, into_type: type | None = None) -> bool:
Returns:
Whether there is a serializer for the type.
"""
- return has_serializer(type_, into_type) or (
+ return (
isinstance(type_, type)
and dataclasses.is_dataclass(type_)
and (into_type is None or into_type is dict)
- )
+ ) or has_serializer(type_, into_type)
@serializer(to=str)
diff --git a/reflex/utils/types.py b/reflex/utils/types.py
index 6b83486cf7f..93e253b22c5 100644
--- a/reflex/utils/types.py
+++ b/reflex/utils/types.py
@@ -52,11 +52,8 @@
GenericType = type | _GenericAlias
# Valid state var types.
-JSONType = {str, int, float, bool}
-PrimitiveType = int | float | bool | str | list | dict | set | tuple
PrimitiveTypes = (int, float, bool, str, list, dict, set, tuple)
-StateVar = PrimitiveType | Base | None
-StateIterVar = list | set | tuple
+StateVarTypes = (*PrimitiveTypes, Base, type(None))
if TYPE_CHECKING:
from reflex.vars.base import Var
@@ -401,6 +398,8 @@ def get_field_type(cls: GenericType, field_name: str) -> GenericType | None:
Returns:
The type of the field, if it exists, else None.
"""
+ if (fields := getattr(cls, "_fields", None)) is not None and field_name in fields:
+ return fields[field_name].annotated_type
if (
hasattr(cls, "__fields__")
and field_name in cls.__fields__
@@ -857,8 +856,11 @@ def is_valid_var_type(type_: type) -> bool:
if is_union(type_):
return all(is_valid_var_type(arg) for arg in get_args(type_))
+
+ type_ = origin if (origin := get_origin(type_)) is not None else type_
+
return (
- _issubclass(type_, StateVar)
+ issubclass(type_, StateVarTypes)
or serializers.has_serializer(type_)
or dataclasses.is_dataclass(type_)
)
@@ -993,11 +995,6 @@ def validate_literal(key: str, value: Any, expected_type: type, comp_name: str):
raise ValueError(msg)
-# Store this here for performance.
-StateBases = get_base_class(StateVar)
-StateIterBases = get_base_class(StateIterVar)
-
-
def safe_issubclass(cls: Any, cls_check: Any | tuple[Any, ...]):
"""Check if a class is a subclass of another class. Returns False if internal error occurs.
diff --git a/reflex/vars/base.py b/reflex/vars/base.py
index 020d09aea05..554ddce899d 100644
--- a/reflex/vars/base.py
+++ b/reflex/vars/base.py
@@ -366,7 +366,7 @@ def can_use_in_object_var(cls: GenericType) -> bool:
if types.is_union(cls):
return all(can_use_in_object_var(t) for t in types.get_args(cls))
return (
- inspect.isclass(cls)
+ isinstance(cls, type)
and not safe_issubclass(cls, Var)
and serializers.can_serialize(cls, dict)
)
@@ -516,6 +516,17 @@ def _get_all_var_data(self) -> VarData | None:
"""
return self._var_data
+ def __deepcopy__(self, memo: dict[int, Any]) -> Self:
+ """Deepcopy the var.
+
+ Args:
+ memo: The memo dictionary to use for the deepcopy.
+
+ Returns:
+ A deepcopy of the var.
+ """
+ return self
+
def equals(self, other: Var) -> bool:
"""Check if two vars are equal.
@@ -795,7 +806,7 @@ def to(
if can_use_in_object_var(output):
return self.to(ObjectVar, output)
- if inspect.isclass(output):
+ if isinstance(output, type):
for var_subclass in _var_subclasses[::-1]:
if safe_issubclass(output, var_subclass.var_subclass):
current_var_type = self._var_type
@@ -891,7 +902,7 @@ def guess_type(self) -> Var:
args = get_args(var_type)
fixed_type = unionize(*(type(arg) for arg in args))
- if not inspect.isclass(fixed_type):
+ if not isinstance(fixed_type, type):
msg = f"Unsupported type {var_type} for guess_type."
raise TypeError(msg)
@@ -1153,18 +1164,6 @@ def _without_data(self):
"""
return dataclasses.replace(self, _var_data=None)
- def __get__(self, instance: Any, owner: Any):
- """Get the var.
-
- Args:
- instance: The instance to get the var from.
- owner: The owner of the var.
-
- Returns:
- The var.
- """
- return self
-
def _decode(self) -> Any:
"""Decode Var as a python value.
@@ -1409,7 +1408,7 @@ def __init_subclass__(cls, **kwargs):
bases = cls.__bases__
bases_normalized = [
- base if inspect.isclass(base) else get_origin(base) for base in bases
+ base if isinstance(base, type) else get_origin(base) for base in bases
]
possible_bases = [
diff --git a/reflex/vars/object.py b/reflex/vars/object.py
index 3d514926624..53f682e8ce6 100644
--- a/reflex/vars/object.py
+++ b/reflex/vars/object.py
@@ -6,7 +6,6 @@
import dataclasses
import typing
from collections.abc import Mapping
-from inspect import isclass
from typing import (
Any,
NoReturn,
@@ -328,7 +327,10 @@ def __getattr__(self, name: str) -> Var:
if (
is_typeddict(fixed_type)
- or (isclass(fixed_type) and not safe_issubclass(fixed_type, Mapping))
+ or (
+ isinstance(fixed_type, type)
+ and not safe_issubclass(fixed_type, Mapping)
+ )
or (fixed_type in types.UnionTypes)
):
attribute_type = get_attribute_access_type(var_type, name)
diff --git a/tests/units/plugins/test_sitemap.py b/tests/units/plugins/test_sitemap.py
index f79a4144da7..ba1af78807e 100644
--- a/tests/units/plugins/test_sitemap.py
+++ b/tests/units/plugins/test_sitemap.py
@@ -11,7 +11,8 @@
def test_generate_xml_empty_links():
"""Test generate_xml with an empty list of links."""
xml_output = generate_xml([])
- expected = '\n\n'
+ expected = """
+"""
assert xml_output == expected
@@ -19,13 +20,12 @@ def test_generate_xml_single_link_loc_only():
"""Test generate_xml with a single link having only loc."""
links: list[SitemapLink] = [{"loc": "https://example.com"}]
xml_output = generate_xml(links)
- expected = """
+ expected = """
https://example.com
-
-"""
+"""
assert xml_output == expected
@@ -47,7 +47,7 @@ def test_generate_xml_multiple_links_all_fields():
},
]
xml_output = generate_xml(links)
- expected = """
+ expected = """
https://example.com/page1
@@ -61,8 +61,7 @@ def test_generate_xml_multiple_links_all_fields():
2023-01-01T00:00:00
0.5
-
-"""
+"""
assert xml_output == expected