diff --git a/frontend/src/components/editor/Output.tsx b/frontend/src/components/editor/Output.tsx index c1a584683b1..2b3f6296135 100644 --- a/frontend/src/components/editor/Output.tsx +++ b/frontend/src/components/editor/Output.tsx @@ -141,10 +141,17 @@ export const OutputRenderer: React.FC<{ case "image/bmp": case "image/gif": case "image/jpeg": + case "image/svg+xml": invariant( typeof data === "string", `Expected string data for mime=${mimetype}. Got ${typeof data}`, ); + if ( + mimetype === "image/svg+xml" && + !data.startsWith("data:image/svg+xml;base64,") + ) { + return renderHTML({ html: data, alwaysSanitizeHtml: true }); + } return ( ); - case "image/svg+xml": - invariant( - typeof data === "string", - `Expected string data for mime=${mimetype}. Got ${typeof data}`, - ); - return renderHTML({ html: data, alwaysSanitizeHtml: true }); case "video/mp4": case "video/mpeg": diff --git a/frontend/src/components/editor/__tests__/Output.test.tsx b/frontend/src/components/editor/__tests__/Output.test.tsx index 0d45379cc67..deae467f152 100644 --- a/frontend/src/components/editor/__tests__/Output.test.tsx +++ b/frontend/src/components/editor/__tests__/Output.test.tsx @@ -64,3 +64,62 @@ describe("OutputRenderer renderFallback prop", () => { ).toBeInTheDocument(); }); }); + +describe("OutputRenderer image and SVG rendering", () => { + const plainSvgString = + ''; + const base64SvgDataUrl = + "data:image/svg+xml;base64,PHN2Zz48cmVjdCB4PSIwIiB5PSIw"; + const base64PngDataUrl = + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB"; + + it("should render plain SVG string via renderHTML", () => { + const { container } = render( + , + ); + const svgElement = container.querySelector("svg"); + expect(svgElement).not.toBeNull(); + const rectElement = svgElement!.querySelector("rect"); + expect(rectElement).not.toBeNull(); + const imgElement = container.querySelector("img"); + expect(imgElement).toBeNull(); + }); + + it("should render Base64 SVG data URL via ImageOutput", () => { + const { container } = render( + , + ); + const imgElement = container.querySelector("img"); + expect(imgElement).not.toBeNull(); + expect(imgElement).toHaveAttribute("src", base64SvgDataUrl); + const svgElement = container.querySelector("svg"); + expect(svgElement).toBeNull(); + }); + + it("should render Base64 PNG data URL via ImageOutput", () => { + const { container } = render( + , + ); + const imgElement = container.querySelector("img"); + expect(imgElement).not.toBeNull(); + expect(imgElement).toHaveAttribute("src", base64PngDataUrl); + }); +}); diff --git a/marimo/_convert/ipynb/from_ir.py b/marimo/_convert/ipynb/from_ir.py index e4970d97f42..cc2643e4546 100644 --- a/marimo/_convert/ipynb/from_ir.py +++ b/marimo/_convert/ipynb/from_ir.py @@ -3,6 +3,7 @@ from __future__ import annotations +import base64 import io import json import re @@ -191,6 +192,10 @@ def _add_marimo_metadata( def _maybe_extract_dataurl(data: Any) -> Any: + if isinstance(data, str) and data.startswith("data:image/svg+xml;base64,"): + # Decode SVG from base64 to plain text XML + payload = data[len("data:image/svg+xml;base64,") :] + return base64.b64decode(payload).decode() if ( isinstance(data, str) and data.startswith("data:") diff --git a/marimo/_output/mpl.py b/marimo/_output/mpl.py index 0b5b8cae746..650fcb4303d 100644 --- a/marimo/_output/mpl.py +++ b/marimo/_output/mpl.py @@ -66,11 +66,20 @@ def _render_figure_mimebundle( fig: Matplotlib figure canvas to render Returns: - Tuple of (mimetype, json_data) where json_data is a mimebundle - containing the PNG data URL and display metadata + Tuple of (mimetype, data). If `matplotlib.rcParams["savefig.format"]` is 'svg', + mimetype is 'image/svg+xml' and data is the Base64-encoded SVG data URL. + Otherwise, mimetype is 'application/vnd.marimo+mimebundle' and data is a JSON string + representing a mimebundle containing the PNG data URL and display metadata. """ buf = io.BytesIO() + if plt.rcParams["savefig.format"] == "svg": + fig.figure.savefig(buf, format="svg", bbox_inches="tight") # type: ignore[attr-defined] + svg_bytes = buf.getvalue() + plot_bytes = base64.b64encode(svg_bytes) + data_url = build_data_url(mimetype="image/svg+xml", data=plot_bytes) + return "image/svg+xml", data_url + # Get current DPI and double it for retina display (like Jupyter) original_dpi = fig.figure.dpi # type: ignore[attr-defined] retina_dpi = original_dpi * 2 diff --git a/tests/_convert/ipynb/test_from_ir.py b/tests/_convert/ipynb/test_from_ir.py index ef33bbd5c3e..52f5b9f15cc 100644 --- a/tests/_convert/ipynb/test_from_ir.py +++ b/tests/_convert/ipynb/test_from_ir.py @@ -88,6 +88,11 @@ def __(): "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAUA", "iVBORw0KGgoAAAANSUhEUgAAAAUA", ), + # SVG string from Base64 data URL + ( + "data:image/svg+xml;base64,PHN2Zz48L3N2Zz4=", + "", + ), # Non-data-URL string passes through ("hello world", "hello world"), # Dict passes through @@ -99,6 +104,7 @@ def __(): ], ids=[ "base64_data_url", + "svg_string_from_base64_data_url", "regular_string", "dict_passthrough", "int_passthrough", diff --git a/tests/_output/formatters/test_matplotlib.py b/tests/_output/formatters/test_matplotlib.py index 891374b3a67..fe2849092ae 100644 --- a/tests/_output/formatters/test_matplotlib.py +++ b/tests/_output/formatters/test_matplotlib.py @@ -243,3 +243,40 @@ async def test_matplotlib_backwards_compatibility( assert mime_type == "application/vnd.marimo+mimebundle" mimebundle = json.loads(data) assert "image/png" in mimebundle + + +@pytest.mark.skipif(not HAS_MPL, reason="optional dependencies not installed") +async def test_matplotlib_svg_rendering( + executing_kernel: Kernel, exec_req: ExecReqProvider +) -> None: + """Test that matplotlib figures are rendered in SVG format.""" + from marimo._output.formatters.formatters import register_formatters + + register_formatters(theme="light") + + await executing_kernel.run( + [ + exec_req.get( + """ + import matplotlib.pyplot as plt + + fmt = plt.rcParams["savefig.format"] + plt.rcParams["savefig.format"] = "svg" + + # Create a simple figure + fig, ax = plt.subplots(figsize=(4, 3)) + ax.plot([1, 2, 3], [1, 2, 3]) + result = fig._mime_() + + plt.rcParams["savefig.format"] = fmt + """ + ) + ] + ) + + # Get the formatted result from kernel globals + mime_type, data = executing_kernel.globals["result"] + + assert mime_type == "image/svg+xml" + assert isinstance(data, str) + assert data.startswith("data:image/svg+xml;base64,PD94")