From c5657a68b37001d192cebb3bddd3df984ed2f57d Mon Sep 17 00:00:00 2001
From: Tyler Mathis <35553152+tsmathis@users.noreply.github.com>
Date: Thu, 22 Jan 2026 16:39:51 -0800
Subject: [PATCH] Revert "Merge pull request #490 from minhsueh/phonon_v2"
This reverts commit a83185c5d00ba437cb59da986d8501741f9370c2, reversing
changes made to d6efe1b9b0319bc32c503a5efec794d59a5a8bd3.
---
crystal_toolkit/components/phonon.py | 538 ++++++++-------------------
1 file changed, 146 insertions(+), 392 deletions(-)
diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py
index 5ef76d90..296a2947 100644
--- a/crystal_toolkit/components/phonon.py
+++ b/crystal_toolkit/components/phonon.py
@@ -1,41 +1,39 @@
from __future__ import annotations
import itertools
-from copy import deepcopy
from typing import TYPE_CHECKING, Any
import numpy as np
import plotly.graph_objects as go
from dash import dcc, html
-from dash.dependencies import Component, Input, Output, State
+from dash.dependencies import Component, Input, Output
from dash.exceptions import PreventUpdate
-from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene
-
-# crystal animation algo
-from pymatgen.analysis.graphs import StructureGraph
-from pymatgen.analysis.local_env import CrystalNN
+from dash_mp_components import CrystalToolkitScene
from pymatgen.ext.matproj import MPRester
from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine
from pymatgen.phonon.dos import CompletePhononDos
from pymatgen.phonon.plotter import PhononBSPlotter
-from pymatgen.transformations.standard_transformations import SupercellTransformation
from crystal_toolkit.core.mpcomponent import MPComponent
from crystal_toolkit.core.panelcomponent import PanelComponent
from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres
-from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list
+from crystal_toolkit.helpers.layouts import (
+ Column,
+ Columns,
+ Label,
+ MessageBody,
+ MessageContainer,
+ get_data_list,
+)
from crystal_toolkit.helpers.pretty_labels import pretty_labels
if TYPE_CHECKING:
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
from pymatgen.electronic_structure.dos import CompleteDos
-DISPLACE_COEF = [0, 1, 0, -1, 0]
-MARKER_COLOR = "red"
-MARKER_SIZE = 12
-MARKER_SHAPE = "x"
-MAX_MAGNITUDE = 300
-MIN_MAGNITUDE = 0
+# Author: Jason Munro, Janosh Riebesell
+# Contact: jmunro@lbl.gov, janosh@lbl.gov
+
# TODOs:
# - look for additional projection methods in phonon DOS (currently only atom
@@ -66,32 +64,26 @@ def __init__(
**kwargs,
)
- bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos(
- self.initial_data["default"]
- )
- self.create_store("bs-store", bs)
- self.create_store("bs", None)
- self.create_store("dos", None)
-
@property
def _sub_layouts(self) -> dict[str, Component]:
# defaults
state = {"label-select": "sc", "dos-select": "ap"}
- fig = PhononBandstructureAndDosComponent.get_figure(None, None)
+ bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos(
+ self.initial_data["default"]
+ )
+ fig = PhononBandstructureAndDosComponent.get_figure(bs, dos)
# Main plot
graph = dcc.Graph(
figure=fig,
config={"displayModeBar": False},
- responsive=False,
+ responsive=True,
id=self.id("ph-bsdos-graph"),
)
# Brillouin zone
- zone_scene = self.get_brillouin_zone_scene(None)
- zone = CrystalToolkitScene(
- data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone")
- )
+ zone_scene = self.get_brillouin_zone_scene(bs)
+ zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px")
# Hide by default if not loaded by mpid, switching between k-paths
# on-the-fly only supported for bandstructures retrieved from MP
@@ -113,11 +105,9 @@ def _sub_layouts(self) -> dict[str, Component]:
options=options,
)
],
- style=(
- {"width": "200px"}
- if show_path_options
- else {"maxWidth": "200", "display": "none"}
- ),
+ style={"width": "200px"}
+ if show_path_options
+ else {"maxWidth": "200", "display": "none"},
id=self.id("path-container"),
)
@@ -132,11 +122,9 @@ def _sub_layouts(self) -> dict[str, Component]:
options=options,
)
],
- style=(
- {"width": "200px"}
- if show_path_options
- else {"width": "200px", "display": "none"}
- ),
+ style={"width": "200px"}
+ if show_path_options
+ else {"width": "200px", "display": "none"},
id=self.id("label-container"),
)
@@ -150,82 +138,9 @@ def _sub_layouts(self) -> dict[str, Component]:
style={"width": "200px"},
)
- summary_dict = self._get_data_list_dict(None, None)
+ summary_dict = self._get_data_list_dict(bs, dos)
summary_table = get_data_list(summary_dict)
- # crystal visualization
-
- tip = html.H5(
- "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!",
- )
-
- crystal_animation = html.Div(
- CrystalToolkitAnimationScene(
- data={},
- sceneSize="500px",
- id=self.id("crystal-animation"),
- settings={"defaultZoom": 1.2},
- axisView="SW",
- showControls=False, # disable download for now
- ),
- style={"width": "60%"},
- )
-
- crystal_animation_controls = html.Div(
- [
- html.Br(),
- html.Div(tip, style={"textAlign": "center"}),
- html.Br(),
- html.H5("Control Panel", style={"textAlign": "center"}),
- html.H6("Supercell modification"),
- html.Br(),
- html.Div(
- [
- self.get_numerical_input(
- kwarg_label="scale-x",
- default=1,
- is_int=True,
- label="x",
- min=1,
- style={"width": "5rem"},
- ),
- self.get_numerical_input(
- kwarg_label="scale-y",
- default=1,
- is_int=True,
- label="y",
- min=1,
- style={"width": "5rem"},
- ),
- self.get_numerical_input(
- kwarg_label="scale-z",
- default=1,
- is_int=True,
- label="z",
- min=1,
- style={"width": "5rem"},
- ),
- html.Button(
- "Update",
- id=self.id("supercell-controls-btn"),
- style={"height": "40px"},
- ),
- ],
- style={"display": "flex"},
- ),
- html.Br(),
- html.Div(
- self.get_slider_input(
- kwarg_label="magnitude",
- default=0.5,
- step=0.01,
- domain=[0, 1],
- label="Vibration magnitude",
- )
- ),
- ],
- )
-
return {
"graph": graph,
"convention": convention,
@@ -233,31 +148,10 @@ def _sub_layouts(self) -> dict[str, Component]:
"label-select": label_select,
"zone": zone,
"table": summary_table,
- "crystal-animation": crystal_animation,
- "tip": tip,
- "crystal-animation-controls": crystal_animation_controls,
}
- def _get_animation_panel(self):
- sub_layouts = self._sub_layouts
- return Columns(
- [
- Column(
- [
- Columns(
- [
- sub_layouts["crystal-animation"],
- sub_layouts["crystal-animation-controls"],
- ]
- )
- ]
- ),
- ]
- )
-
def layout(self) -> html.Div:
sub_layouts = self._sub_layouts
- crystal_animation = self._get_animation_panel()
graph = Columns([Column([sub_layouts["graph"]])])
controls = Columns(
[
@@ -272,143 +166,11 @@ def layout(self) -> html.Div:
)
brillouin_zone = Columns(
[
- Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")),
+ Column([Label("Summary"), sub_layouts["table"]]),
Column([Label("Brillouin Zone"), sub_layouts["zone"]]),
]
)
-
- return html.Div([graph, crystal_animation, controls, brillouin_zone])
-
- @staticmethod
- def _get_eigendisplacement(
- ph_bs: BandStructureSymmLine,
- json_data: dict,
- band: int = 0,
- qpoint: int = 0,
- precision: int = 15,
- magnitude: int = MAX_MAGNITUDE / 2,
- total_repeat_cell_cnt: int = 1,
- ) -> dict:
- if not ph_bs or not json_data:
- return {}
-
- assert json_data["contents"][0]["name"] == "atoms"
- assert json_data["contents"][1]["name"] == "bonds"
- rdata = deepcopy(json_data)
-
- def calc_max_displacement(idx: int) -> list:
- """
- Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement.
-
- Parameters:
- idx (int): The atom index.
-
- Returns:
- list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement]
-
- This function extracts the real component of the atom's eigendisplacement,
- scales it by the specified magnitude, and returns the resulting vector.
- """
-
- # get the atom index
- assert total_repeat_cell_cnt != 0
-
- modified_idx = (
- int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx
- )
-
- return [
- round(complex(vec).real * magnitude, precision)
- for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx]
- ]
-
- def calc_animation_step(max_displacement: list, coef: int) -> list:
- """
- Calculate the displacement for an animation frame based on the given coefficient.
-
- Parameters:
- max_displacement (list): A list of maximum displacements along each axis,
- formatted as [x_max_displacement, y_max_displacement, z_max_displacement].
- coef (int): A coefficient indicating the motion direction.
- - 0: no movement
- - 1: forward movement
- - -1: backward movement
-
- Returns:
- list: The displacement vector [x_displacement, y_displacement, z_displacement].
-
- This function generates oscillatory motion by scaling the maximum displacement
- with the provided coefficient.
- """
- return [round(coef * md, precision) for md in max_displacement]
-
- # Compute per-frame atomic motion.
- # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates.
- contents0 = json_data["contents"][0]["contents"]
- for cidx, content in enumerate(contents0):
- max_displacement = calc_max_displacement(content["_meta"][0])
- rcontent = rdata["contents"][0]["contents"][cidx]
- # put animation frame to the given atom index
- rcontent["animate"] = [
- calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF
- ]
- rcontent["keyframes"] = list(range(len(DISPLACE_COEF)))
- rcontent["animateType"] = "displacement"
- # Compute per-frame bonding motion.
- # Explanation:
- # Each bond connects two atoms, `u` and `v`, represented as (u)----(v)
- # To model the bond motion, it is divided into two segments:
- # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v)
- # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid).
- # For each cylinder, displacements are assigned to the endpoints — for example,
- # the (u)--(mid) cylinder uses:
- # [
- # [u_x_displacement, u_y_displacement, u_z_displacement],
- # [mid_x_displacement, mid_y_displacement, mid_z_displacement]
- # ].
- contents1 = json_data["contents"][1]["contents"]
-
- for cidx, content in enumerate(contents1):
- bond_animation = []
- assert len(content["_meta"]) == len(content["positionPairs"])
-
- for atom_idx_pair in content["_meta"]:
- max_displacements = list(
- map(calc_max_displacement, atom_idx_pair)
- ) # max displacement for u and v
-
- u_to_middle_bond_animation = []
-
- for coef in DISPLACE_COEF:
- # Calculate the midpoint displacement between atom u and v for each animation frame.
- u_displacement, v_displacement = [
- np.array(calc_animation_step(max_displacement, coef))
- for max_displacement in max_displacements
- ]
- middle_end_displacement = np.add(u_displacement, v_displacement) / 2
-
- u_to_middle_bond_animation.append(
- [
- u_displacement, # u atom displacement
- [
- round(dis, precision) for dis in middle_end_displacement
- ], # middle point displacement
- ]
- )
-
- bond_animation.append(u_to_middle_bond_animation)
-
- rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation
- rdata["contents"][1]["contents"][cidx]["keyframes"] = list(
- range(len(DISPLACE_COEF))
- )
- rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement"
-
- # remove unused sense
- for i in range(2, 4):
- rdata["contents"][i]["visible"] = False
-
- return rdata
+ return html.Div([graph, controls, brillouin_zone])
@staticmethod
def _get_ph_bs_dos(
@@ -541,7 +303,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
"line": {"color": "#1f77b4"},
"hoverinfo": "skip",
"name": "Total",
- "customdata": [[di, band_num] for di in range(len(x_dat))],
"hovertemplate": "%{y:.2f} THz",
"showlegend": False,
"xaxis": "x",
@@ -587,9 +348,6 @@ def get_ph_bandstructure_traces(bs, freq_range):
def _get_data_list_dict(
bs: PhononBandStructureSymmLine, dos: CompletePhononDos
) -> dict[str, str | bool | int]:
- if (not bs) and (not dos):
- return {}
-
bs_minpoint, bs_min_freq = bs.min_freq()
min_freq_report = (
f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}"
@@ -615,7 +373,7 @@ def _get_data_list_dict(
target="blank",
),
]
- ): ("Yes" if bs.has_nac else "No"),
+ ): "Yes" if bs.has_nac else "No",
"Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No",
"Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No",
"Min frequency": min_freq_report,
@@ -685,9 +443,14 @@ def get_figure(
ph_dos: CompletePhononDos | None = None,
freq_range: tuple[float | None, float | None] = (None, None),
) -> go.Figure:
+ if freq_range[0] is None:
+ freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
+
+ if freq_range[1] is None:
+ freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
+
if (not ph_dos) and (not ph_bs):
empty_plot_style = {
- "height": 500,
"xaxis": {"visible": False},
"yaxis": {"visible": False},
"paper_bgcolor": "rgba(0,0,0,0)",
@@ -696,12 +459,6 @@ def get_figure(
return go.Figure(layout=empty_plot_style)
- if freq_range[0] is None:
- freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1])
-
- if freq_range[1] is None:
- freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05)
-
if ph_bs:
(
bs_traces,
@@ -798,7 +555,7 @@ def get_figure(
paper_bgcolor="rgba(0,0,0,0)",
plot_bgcolor="rgba(230,230,230,230)",
margin=dict(l=60, b=50, t=50, pad=0, r=30),
- clickmode="event+select",
+ # clickmode="event+select"
)
figure = {"data": bs_traces + dos_traces, "layout": layout}
@@ -823,57 +580,124 @@ def get_figure(
def generate_callbacks(self, app, cache) -> None:
@app.callback(
Output(self.id("ph-bsdos-graph"), "figure"),
- Output(self.id("zone"), "data"),
- Output(self.id("table"), "children"),
- Input(self.id("ph_bs"), "data"),
- Input(self.id("ph_dos"), "data"),
- Input(self.id("ph-bsdos-graph"), "clickData"),
+ Input(self.id("traces"), "data"),
)
- def update_graph(bs, dos, nclick):
- if isinstance(bs, dict):
- bs = PhononBandStructureSymmLine.from_dict(bs)
- if isinstance(dos, dict):
- dos = CompletePhononDos.from_dict(dos)
+ def update_graph(traces):
+ if traces == "error":
+ msg_body = MessageBody(
+ dcc.Markdown(
+ "Band structure and density of states not available for this selection."
+ )
+ )
+ return (MessageContainer([msg_body], kind="warning"),)
+
+ if traces is None:
+ raise PreventUpdate
+
+ bs, dos = self._get_ph_bs_dos(self.initial_data["default"])
figure = self.get_figure(bs, dos)
+ return dcc.Graph(
+ figure=figure, config={"displayModeBar": False}, responsive=True
+ )
- # remove marker if there is one
- figure["data"] = [
- t for t in figure["data"] if t.get("name") != "click-marker"
- ]
+ @app.callback(
+ Output(self.id("label-select"), "value"),
+ Output(self.id("label-container"), "style"),
+ Input(self.id("mpid"), "data"),
+ Input(self.id("path-convention"), "value"),
+ )
+ def update_label_select(mpid, path_convention):
+ if not mpid:
+ raise PreventUpdate
+ label_value = path_convention
+ label_style = {"maxWidth": "200"}
- x_click = nclick["points"][0]["x"] if nclick else 0
- y_click = nclick["points"][0]["y"] if nclick else 0
- pt = nclick["points"][0] if nclick else {}
+ return label_value, label_style
- qpoint, band_num = pt.get("customdata", [0, 0])
+ @app.callback(
+ Output(self.id("dos-select"), "options"),
+ Output(self.id("path-convention"), "options"),
+ Output(self.id("path-container"), "style"),
+ Input(self.id("elements"), "data"),
+ Input(self.id("mpid"), "data"),
+ )
+ def update_select(elements, mpid):
+ if elements is None:
+ raise PreventUpdate
+ if not mpid:
+ dos_options = (
+ [{"label": "Element Projected", "value": "ap"}]
+ + [{"label": "Orbital Projected - Total", "value": "op"}]
+ + [
+ {
+ "label": "Orbital Projected - " + str(ele_label),
+ "value": "orb" + str(ele_label),
+ }
+ for ele_label in elements
+ ]
+ )
- figure["data"].append(
- {
- "type": "scatter",
- "mode": "markers",
- "x": [x_click],
- "y": [y_click],
- "marker": {
- "color": MARKER_COLOR,
- "size": MARKER_SIZE,
- "symbol": MARKER_SHAPE,
- },
- "name": "click-marker",
- "showlegend": False,
- "customdata": [[qpoint, band_num]],
- "hovertemplate": (
- "band: %{customdata[1]}
q-point: %{customdata[0]}
"
- ),
- }
+ path_options = [{"label": "N/A", "value": "sc"}]
+ path_style = {"maxWidth": "200", "display": "none"}
+
+ return dos_options, path_options, path_style
+ dos_options = (
+ [{"label": "Element Projected", "value": "ap"}]
+ + [{"label": "Orbital Projected - Total", "value": "op"}]
+ + [
+ {
+ "label": "Orbital Projected - " + str(ele_label),
+ "value": "orb" + str(ele_label),
+ }
+ for ele_label in elements
+ ]
)
- zone_scene = self.get_brillouin_zone_scene(bs)
+ path_options = [
+ {"label": "Setyawan-Curtarolo", "value": "sc"},
+ {"label": "Latimer-Munro", "value": "lm"},
+ {"label": "Hinuma et al.", "value": "hin"},
+ ]
- summary_dict = self._get_data_list_dict(bs, dos)
- summary_table = get_data_list(summary_dict)
+ path_style = {"maxWidth": "200"}
- return figure, zone_scene.to_json(), summary_table
+ return dos_options, path_options, path_style
+
+ @app.callback(
+ Output(self.id("traces"), "data"),
+ Output(self.id("elements"), "data"),
+ Input(self.id(), "data"),
+ Input(self.id("path-convention"), "value"),
+ Input(self.id("dos-select"), "value"),
+ Input(self.id("label-select"), "value"),
+ )
+ def bs_dos_data(data, dos_select, label_select):
+ # Obtain bands to plot over and generate traces for bs data:
+ energy_window = (-6.0, 10.0)
+
+ traces = []
+
+ bsml, density_of_states = self._get_ph_bs_dos(data)
+
+ if self.bandstructure_symm_line:
+ bs_traces = self.get_ph_bandstructure_traces(
+ bsml, freq_range=energy_window
+ )
+ traces.append(bs_traces)
+
+ if self.density_of_states:
+ dos_traces = self.get_ph_dos_traces(
+ density_of_states, freq_range=energy_window
+ )
+ traces.append(dos_traces)
+
+ # traces = [bs_traces, dos_traces, bs_data]
+
+ # TODO: not tested if this is correct way to get element list
+ elements = list(map(str, density_of_states.get_element_dos()))
+
+ return traces, elements
@app.callback(
Output(self.id("brillouin-zone"), "data"),
@@ -887,78 +711,8 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select):
# TODO: figure out what to return (CSS?) to highlight BZ edge/point
return
- @app.callback(
- Output(self.id("crystal-animation"), "data"),
- Input(self.id("ph-bsdos-graph"), "clickData"),
- Input(self.id("ph_bs"), "data"),
- Input(self.id("supercell-controls-btn"), "n_clicks"),
- Input(self.get_kwarg_id("magnitude"), "value"),
- State(self.get_kwarg_id("scale-x"), "value"),
- State(self.get_kwarg_id("scale-y"), "value"),
- State(self.get_kwarg_id("scale-z"), "value"),
- # prevent_initial_call=True
- )
- def update_crystal_animation(
- cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z
- ):
- # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields,
- # ensuring updates occur only after the `supercell-controls-btn`` is clicked.
-
- if not bs:
- raise PreventUpdate
-
- # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values.
- # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value,
- # this approach provides better clarity and readability.
- kwargs = self.reconstruct_kwargs_from_state()
- magnitude_fraction = kwargs.get("magnitude")
- scale_x = kwargs.get("scale-x")
- scale_y = kwargs.get("scale-y")
- scale_z = kwargs.get("scale-z")
-
- if isinstance(bs, dict):
- bs = PhononBandStructureSymmLine.from_dict(bs)
-
- struct = bs.structure
- total_repeat_cell_cnt = 1
- # update structure if the controls got triggered
- if sueprcell_update:
- total_repeat_cell_cnt = scale_x * scale_y * scale_z
-
- # create supercell
- trans = SupercellTransformation(
- ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z))
- )
- struct = trans.apply_transformation(struct)
-
- struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN())
- scene = struc_graph.get_scene(
- draw_image_atoms=False,
- bonded_sites_outside_unit_cell=False,
- site_get_scene_kwargs={"retain_atom_idx": True},
- )
- json_data = scene.to_json()
-
- qpoint = 0
- band_num = 0
-
- if cd and cd.get("points"):
- pt = cd["points"][0]
- qpoint, band_num = pt.get("customdata", [0, 0])
-
- # magnitude
- magnitude = (
- MAX_MAGNITUDE - MIN_MAGNITUDE
- ) * magnitude_fraction + MIN_MAGNITUDE
-
- return PhononBandstructureAndDosComponent._get_eigendisplacement(
- ph_bs=bs,
- json_data=json_data,
- band=band_num,
- qpoint=qpoint,
- total_repeat_cell_cnt=total_repeat_cell_cnt,
- magnitude=magnitude,
- )
+ # TODO: figure out what to return (CSS?) to highlight BZ edge/point
+ return
class PhononBandstructureAndDosPanelComponent(PanelComponent):