diff --git a/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js b/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js index 72d1f69ebfa..46c4471a2e1 100644 --- a/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js +++ b/reflex/.templates/web/components/reflex/radix_themes_color_mode_provider.js @@ -1,5 +1,5 @@ import { useTheme } from "$/utils/react-theme"; -import { createElement } from "react"; +import { createElement, useEffect } from "react"; import { ColorModeContext, defaultColorMode } from "$/utils/context"; export default function RadixThemesColorModeProvider({ children }) { @@ -20,6 +20,16 @@ export default function RadixThemesColorModeProvider({ children }) { setTheme(mode); }; + useEffect(() => { + const radixRoot = document.querySelector( + '.radix-themes[data-is-root-theme="true"]', + ); + if (radixRoot) { + radixRoot.classList.remove("light", "dark"); + radixRoot.classList.add(resolvedTheme); + } + }, [resolvedTheme]); + return createElement( ColorModeContext.Provider, { diff --git a/reflex/compiler/templates.py b/reflex/compiler/templates.py index bf525b2fa51..2cccdb41eb4 100644 --- a/reflex/compiler/templates.py +++ b/reflex/compiler/templates.py @@ -200,7 +200,6 @@ def app_root_template( ) return f""" -import reflexGlobalStyles from '$/styles/__reflex_global_styles.css?url'; {imports_str} {dynamic_imports_str} import {{ EventLoopProvider, StateProvider, defaultColorMode }} from "$/utils/context"; @@ -211,10 +210,6 @@ def app_root_template( {custom_code_str} -export const links = () => [ - {{ rel: 'stylesheet', href: reflexGlobalStyles, type: 'text/css' }} -]; - function AppWrap({{children}}) {{ {_render_hooks(hooks)} return ({_RenderUtils.render(render)}) diff --git a/reflex/compiler/utils.py b/reflex/compiler/utils.py index 3a8a33a5a9c..9a748ca1f7e 100644 --- a/reflex/compiler/utils.py +++ b/reflex/compiler/utils.py @@ -16,7 +16,7 @@ from reflex.components.base.document import Links, ScrollRestoration from reflex.components.base.document import Meta as ReactMeta from reflex.components.component import Component, ComponentStyle, CustomComponent -from reflex.components.el.elements.metadata import Head, Meta, Title +from reflex.components.el.elements.metadata import Head, Link, Meta, Title from reflex.components.el.elements.other import Html from reflex.components.el.elements.sectioning import Body from reflex.constants.state import FIELD_MARKER @@ -26,7 +26,7 @@ from reflex.utils import format, imports, path_ops from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.utils.prerequisites import get_web_dir -from reflex.vars.base import Field, Var +from reflex.vars.base import Field, Var, VarData # To re-export this function. merge_imports = imports.merge_imports @@ -382,6 +382,20 @@ def create_document_root( # Always include the framework meta and link tags. always_head_components = [ ReactMeta.create(), + Link.create( + rel="stylesheet", + type="text/css", + href=Var( + "reflexGlobalStyles", + _var_data=VarData( + imports={ + "$/styles/__reflex_global_styles.css?url": [ + ImportVar(tag="reflexGlobalStyles", is_default=True) + ] + } + ), + ), + ), Links.create(), ] maybe_head_components = [] diff --git a/tests/units/compiler/test_compiler.py b/tests/units/compiler/test_compiler.py index 920249d216b..113b502b74d 100644 --- a/tests/units/compiler/test_compiler.py +++ b/tests/units/compiler/test_compiler.py @@ -8,6 +8,7 @@ from reflex import constants from reflex.compiler import compiler, utils from reflex.components.base import document +from reflex.components.el.elements.metadata import Link from reflex.constants.compiler import PageNames from reflex.utils.imports import ImportVar, ParsedImportDict from reflex.vars.base import Var @@ -364,7 +365,7 @@ def test_create_document_root(): assert isinstance(lang, LiteralStringVar) assert lang.equals(Var.create("en")) # No children in head. - assert len(root.children[0].children) == 5 + assert len(root.children[0].children) == 6 assert isinstance(root.children[0].children[1], utils.Meta) char_set = root.children[0].children[1].char_set # pyright: ignore [reportAttributeAccessIssue] assert isinstance(char_set, LiteralStringVar) @@ -374,7 +375,8 @@ def test_create_document_root(): assert isinstance(name, LiteralStringVar) assert name.equals(Var.create("viewport")) assert isinstance(root.children[0].children[3], document.Meta) - assert isinstance(root.children[0].children[4], document.Links) + assert isinstance(root.children[0].children[4], Link) + assert isinstance(root.children[0].children[5], document.Links) def test_create_document_root_with_scripts(): @@ -389,9 +391,18 @@ def test_create_document_root_with_scripts(): html_custom_attrs={"project": "reflex"}, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 7 + assert len(root.children[0].children) == 8 names = [c.tag for c in root.children[0].children] - assert names == ["script", "Scripts", "Scripts", "meta", "meta", "Meta", "Links"] + assert names == [ + "script", + "Scripts", + "Scripts", + "meta", + "meta", + "Meta", + "link", + "Links", + ] lang = root.lang # pyright: ignore [reportAttributeAccessIssue] assert isinstance(lang, LiteralStringVar) assert lang.equals(Var.create("rx")) @@ -408,9 +419,9 @@ def test_create_document_root_with_meta_char_set(): head_components=comps, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 5 + assert len(root.children[0].children) == 6 names = [c.tag for c in root.children[0].children] - assert names == ["script", "meta", "meta", "Meta", "Links"] + assert names == ["script", "meta", "meta", "Meta", "link", "Links"] assert str(root.children[0].children[1].char_set) == '"cp1252"' # pyright: ignore [reportAttributeAccessIssue] @@ -424,9 +435,9 @@ def test_create_document_root_with_meta_viewport(): head_components=comps, ) assert isinstance(root, utils.Html) - assert len(root.children[0].children) == 6 + assert len(root.children[0].children) == 7 names = [c.tag for c in root.children[0].children] - assert names == ["script", "meta", "meta", "meta", "Meta", "Links"] + assert names == ["script", "meta", "meta", "meta", "Meta", "link", "Links"] assert str(root.children[0].children[1].http_equiv) == '"refresh"' # pyright: ignore [reportAttributeAccessIssue] assert str(root.children[0].children[2].name) == '"viewport"' # pyright: ignore [reportAttributeAccessIssue] assert str(root.children[0].children[2].content) == '"foo"' # pyright: ignore [reportAttributeAccessIssue]