diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 5ef76d90..296a2947 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -1,41 +1,39 @@ from __future__ import annotations import itertools -from copy import deepcopy from typing import TYPE_CHECKING, Any import numpy as np import plotly.graph_objects as go from dash import dcc, html -from dash.dependencies import Component, Input, Output, State +from dash.dependencies import Component, Input, Output from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene - -# crystal animation algo -from pymatgen.analysis.graphs import StructureGraph -from pymatgen.analysis.local_env import CrystalNN +from dash_mp_components import CrystalToolkitScene from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.phonon.plotter import PhononBSPlotter -from pymatgen.transformations.standard_transformations import SupercellTransformation from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres -from crystal_toolkit.helpers.layouts import Column, Columns, Label, get_data_list +from crystal_toolkit.helpers.layouts import ( + Column, + Columns, + Label, + MessageBody, + MessageContainer, + get_data_list, +) from crystal_toolkit.helpers.pretty_labels import pretty_labels if TYPE_CHECKING: from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine from pymatgen.electronic_structure.dos import CompleteDos -DISPLACE_COEF = [0, 1, 0, -1, 0] -MARKER_COLOR = "red" -MARKER_SIZE = 12 -MARKER_SHAPE = "x" -MAX_MAGNITUDE = 300 -MIN_MAGNITUDE = 0 +# Author: Jason Munro, Janosh Riebesell +# Contact: jmunro@lbl.gov, janosh@lbl.gov + # TODOs: # - look for additional projection methods in phonon DOS (currently only atom @@ -66,32 +64,26 @@ def __init__( **kwargs, ) - bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - self.create_store("bs-store", bs) - self.create_store("bs", None) - self.create_store("dos", None) - @property def _sub_layouts(self) -> dict[str, Component]: # defaults state = {"label-select": "sc", "dos-select": "ap"} - fig = PhononBandstructureAndDosComponent.get_figure(None, None) + bs, dos = PhononBandstructureAndDosComponent._get_ph_bs_dos( + self.initial_data["default"] + ) + fig = PhononBandstructureAndDosComponent.get_figure(bs, dos) # Main plot graph = dcc.Graph( figure=fig, config={"displayModeBar": False}, - responsive=False, + responsive=True, id=self.id("ph-bsdos-graph"), ) # Brillouin zone - zone_scene = self.get_brillouin_zone_scene(None) - zone = CrystalToolkitScene( - data=zone_scene.to_json(), sceneSize="500px", id=self.id("zone") - ) + zone_scene = self.get_brillouin_zone_scene(bs) + zone = CrystalToolkitScene(data=zone_scene.to_json(), sceneSize="500px") # Hide by default if not loaded by mpid, switching between k-paths # on-the-fly only supported for bandstructures retrieved from MP @@ -113,11 +105,9 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style=( - {"width": "200px"} - if show_path_options - else {"maxWidth": "200", "display": "none"} - ), + style={"width": "200px"} + if show_path_options + else {"maxWidth": "200", "display": "none"}, id=self.id("path-container"), ) @@ -132,11 +122,9 @@ def _sub_layouts(self) -> dict[str, Component]: options=options, ) ], - style=( - {"width": "200px"} - if show_path_options - else {"width": "200px", "display": "none"} - ), + style={"width": "200px"} + if show_path_options + else {"width": "200px", "display": "none"}, id=self.id("label-container"), ) @@ -150,82 +138,9 @@ def _sub_layouts(self) -> dict[str, Component]: style={"width": "200px"}, ) - summary_dict = self._get_data_list_dict(None, None) + summary_dict = self._get_data_list_dict(bs, dos) summary_table = get_data_list(summary_dict) - # crystal visualization - - tip = html.H5( - "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", - ) - - crystal_animation = html.Div( - CrystalToolkitAnimationScene( - data={}, - sceneSize="500px", - id=self.id("crystal-animation"), - settings={"defaultZoom": 1.2}, - axisView="SW", - showControls=False, # disable download for now - ), - style={"width": "60%"}, - ) - - crystal_animation_controls = html.Div( - [ - html.Br(), - html.Div(tip, style={"textAlign": "center"}), - html.Br(), - html.H5("Control Panel", style={"textAlign": "center"}), - html.H6("Supercell modification"), - html.Br(), - html.Div( - [ - self.get_numerical_input( - kwarg_label="scale-x", - default=1, - is_int=True, - label="x", - min=1, - style={"width": "5rem"}, - ), - self.get_numerical_input( - kwarg_label="scale-y", - default=1, - is_int=True, - label="y", - min=1, - style={"width": "5rem"}, - ), - self.get_numerical_input( - kwarg_label="scale-z", - default=1, - is_int=True, - label="z", - min=1, - style={"width": "5rem"}, - ), - html.Button( - "Update", - id=self.id("supercell-controls-btn"), - style={"height": "40px"}, - ), - ], - style={"display": "flex"}, - ), - html.Br(), - html.Div( - self.get_slider_input( - kwarg_label="magnitude", - default=0.5, - step=0.01, - domain=[0, 1], - label="Vibration magnitude", - ) - ), - ], - ) - return { "graph": graph, "convention": convention, @@ -233,31 +148,10 @@ def _sub_layouts(self) -> dict[str, Component]: "label-select": label_select, "zone": zone, "table": summary_table, - "crystal-animation": crystal_animation, - "tip": tip, - "crystal-animation-controls": crystal_animation_controls, } - def _get_animation_panel(self): - sub_layouts = self._sub_layouts - return Columns( - [ - Column( - [ - Columns( - [ - sub_layouts["crystal-animation"], - sub_layouts["crystal-animation-controls"], - ] - ) - ] - ), - ] - ) - def layout(self) -> html.Div: sub_layouts = self._sub_layouts - crystal_animation = self._get_animation_panel() graph = Columns([Column([sub_layouts["graph"]])]) controls = Columns( [ @@ -272,143 +166,11 @@ def layout(self) -> html.Div: ) brillouin_zone = Columns( [ - Column([Label("Summary"), sub_layouts["table"]], id=self.id("table")), + Column([Label("Summary"), sub_layouts["table"]]), Column([Label("Brillouin Zone"), sub_layouts["zone"]]), ] ) - - return html.Div([graph, crystal_animation, controls, brillouin_zone]) - - @staticmethod - def _get_eigendisplacement( - ph_bs: BandStructureSymmLine, - json_data: dict, - band: int = 0, - qpoint: int = 0, - precision: int = 15, - magnitude: int = MAX_MAGNITUDE / 2, - total_repeat_cell_cnt: int = 1, - ) -> dict: - if not ph_bs or not json_data: - return {} - - assert json_data["contents"][0]["name"] == "atoms" - assert json_data["contents"][1]["name"] == "bonds" - rdata = deepcopy(json_data) - - def calc_max_displacement(idx: int) -> list: - """ - Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. - - Parameters: - idx (int): The atom index. - - Returns: - list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] - - This function extracts the real component of the atom's eigendisplacement, - scales it by the specified magnitude, and returns the resulting vector. - """ - - # get the atom index - assert total_repeat_cell_cnt != 0 - - modified_idx = ( - int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx - ) - - return [ - round(complex(vec).real * magnitude, precision) - for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx] - ] - - def calc_animation_step(max_displacement: list, coef: int) -> list: - """ - Calculate the displacement for an animation frame based on the given coefficient. - - Parameters: - max_displacement (list): A list of maximum displacements along each axis, - formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. - coef (int): A coefficient indicating the motion direction. - - 0: no movement - - 1: forward movement - - -1: backward movement - - Returns: - list: The displacement vector [x_displacement, y_displacement, z_displacement]. - - This function generates oscillatory motion by scaling the maximum displacement - with the provided coefficient. - """ - return [round(coef * md, precision) for md in max_displacement] - - # Compute per-frame atomic motion. - # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. - contents0 = json_data["contents"][0]["contents"] - for cidx, content in enumerate(contents0): - max_displacement = calc_max_displacement(content["_meta"][0]) - rcontent = rdata["contents"][0]["contents"][cidx] - # put animation frame to the given atom index - rcontent["animate"] = [ - calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF - ] - rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) - rcontent["animateType"] = "displacement" - # Compute per-frame bonding motion. - # Explanation: - # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) - # To model the bond motion, it is divided into two segments: - # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) - # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). - # For each cylinder, displacements are assigned to the endpoints — for example, - # the (u)--(mid) cylinder uses: - # [ - # [u_x_displacement, u_y_displacement, u_z_displacement], - # [mid_x_displacement, mid_y_displacement, mid_z_displacement] - # ]. - contents1 = json_data["contents"][1]["contents"] - - for cidx, content in enumerate(contents1): - bond_animation = [] - assert len(content["_meta"]) == len(content["positionPairs"]) - - for atom_idx_pair in content["_meta"]: - max_displacements = list( - map(calc_max_displacement, atom_idx_pair) - ) # max displacement for u and v - - u_to_middle_bond_animation = [] - - for coef in DISPLACE_COEF: - # Calculate the midpoint displacement between atom u and v for each animation frame. - u_displacement, v_displacement = [ - np.array(calc_animation_step(max_displacement, coef)) - for max_displacement in max_displacements - ] - middle_end_displacement = np.add(u_displacement, v_displacement) / 2 - - u_to_middle_bond_animation.append( - [ - u_displacement, # u atom displacement - [ - round(dis, precision) for dis in middle_end_displacement - ], # middle point displacement - ] - ) - - bond_animation.append(u_to_middle_bond_animation) - - rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation - rdata["contents"][1]["contents"][cidx]["keyframes"] = list( - range(len(DISPLACE_COEF)) - ) - rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" - - # remove unused sense - for i in range(2, 4): - rdata["contents"][i]["visible"] = False - - return rdata + return html.Div([graph, controls, brillouin_zone]) @staticmethod def _get_ph_bs_dos( @@ -541,7 +303,6 @@ def get_ph_bandstructure_traces(bs, freq_range): "line": {"color": "#1f77b4"}, "hoverinfo": "skip", "name": "Total", - "customdata": [[di, band_num] for di in range(len(x_dat))], "hovertemplate": "%{y:.2f} THz", "showlegend": False, "xaxis": "x", @@ -587,9 +348,6 @@ def get_ph_bandstructure_traces(bs, freq_range): def _get_data_list_dict( bs: PhononBandStructureSymmLine, dos: CompletePhononDos ) -> dict[str, str | bool | int]: - if (not bs) and (not dos): - return {} - bs_minpoint, bs_min_freq = bs.min_freq() min_freq_report = ( f"{bs_min_freq:.2f} THz at frac. coords. {bs_minpoint.frac_coords}" @@ -615,7 +373,7 @@ def _get_data_list_dict( target="blank", ), ] - ): ("Yes" if bs.has_nac else "No"), + ): "Yes" if bs.has_nac else "No", "Has imaginary frequencies": "Yes" if bs.has_imaginary_freq() else "No", "Has eigen-displacements": "Yes" if bs.has_eigendisplacements else "No", "Min frequency": min_freq_report, @@ -685,9 +443,14 @@ def get_figure( ph_dos: CompletePhononDos | None = None, freq_range: tuple[float | None, float | None] = (None, None), ) -> go.Figure: + if freq_range[0] is None: + freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) + + if freq_range[1] is None: + freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) + if (not ph_dos) and (not ph_bs): empty_plot_style = { - "height": 500, "xaxis": {"visible": False}, "yaxis": {"visible": False}, "paper_bgcolor": "rgba(0,0,0,0)", @@ -696,12 +459,6 @@ def get_figure( return go.Figure(layout=empty_plot_style) - if freq_range[0] is None: - freq_range = (np.min(ph_bs.bands) * 1.05, freq_range[1]) - - if freq_range[1] is None: - freq_range = (freq_range[0], np.max(ph_bs.bands) * 1.05) - if ph_bs: ( bs_traces, @@ -798,7 +555,7 @@ def get_figure( paper_bgcolor="rgba(0,0,0,0)", plot_bgcolor="rgba(230,230,230,230)", margin=dict(l=60, b=50, t=50, pad=0, r=30), - clickmode="event+select", + # clickmode="event+select" ) figure = {"data": bs_traces + dos_traces, "layout": layout} @@ -823,57 +580,124 @@ def get_figure( def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), - Output(self.id("zone"), "data"), - Output(self.id("table"), "children"), - Input(self.id("ph_bs"), "data"), - Input(self.id("ph_dos"), "data"), - Input(self.id("ph-bsdos-graph"), "clickData"), + Input(self.id("traces"), "data"), ) - def update_graph(bs, dos, nclick): - if isinstance(bs, dict): - bs = PhononBandStructureSymmLine.from_dict(bs) - if isinstance(dos, dict): - dos = CompletePhononDos.from_dict(dos) + def update_graph(traces): + if traces == "error": + msg_body = MessageBody( + dcc.Markdown( + "Band structure and density of states not available for this selection." + ) + ) + return (MessageContainer([msg_body], kind="warning"),) + + if traces is None: + raise PreventUpdate + + bs, dos = self._get_ph_bs_dos(self.initial_data["default"]) figure = self.get_figure(bs, dos) + return dcc.Graph( + figure=figure, config={"displayModeBar": False}, responsive=True + ) - # remove marker if there is one - figure["data"] = [ - t for t in figure["data"] if t.get("name") != "click-marker" - ] + @app.callback( + Output(self.id("label-select"), "value"), + Output(self.id("label-container"), "style"), + Input(self.id("mpid"), "data"), + Input(self.id("path-convention"), "value"), + ) + def update_label_select(mpid, path_convention): + if not mpid: + raise PreventUpdate + label_value = path_convention + label_style = {"maxWidth": "200"} - x_click = nclick["points"][0]["x"] if nclick else 0 - y_click = nclick["points"][0]["y"] if nclick else 0 - pt = nclick["points"][0] if nclick else {} + return label_value, label_style - qpoint, band_num = pt.get("customdata", [0, 0]) + @app.callback( + Output(self.id("dos-select"), "options"), + Output(self.id("path-convention"), "options"), + Output(self.id("path-container"), "style"), + Input(self.id("elements"), "data"), + Input(self.id("mpid"), "data"), + ) + def update_select(elements, mpid): + if elements is None: + raise PreventUpdate + if not mpid: + dos_options = ( + [{"label": "Element Projected", "value": "ap"}] + + [{"label": "Orbital Projected - Total", "value": "op"}] + + [ + { + "label": "Orbital Projected - " + str(ele_label), + "value": "orb" + str(ele_label), + } + for ele_label in elements + ] + ) - figure["data"].append( - { - "type": "scatter", - "mode": "markers", - "x": [x_click], - "y": [y_click], - "marker": { - "color": MARKER_COLOR, - "size": MARKER_SIZE, - "symbol": MARKER_SHAPE, - }, - "name": "click-marker", - "showlegend": False, - "customdata": [[qpoint, band_num]], - "hovertemplate": ( - "band: %{customdata[1]}
q-point: %{customdata[0]}
" - ), - } + path_options = [{"label": "N/A", "value": "sc"}] + path_style = {"maxWidth": "200", "display": "none"} + + return dos_options, path_options, path_style + dos_options = ( + [{"label": "Element Projected", "value": "ap"}] + + [{"label": "Orbital Projected - Total", "value": "op"}] + + [ + { + "label": "Orbital Projected - " + str(ele_label), + "value": "orb" + str(ele_label), + } + for ele_label in elements + ] ) - zone_scene = self.get_brillouin_zone_scene(bs) + path_options = [ + {"label": "Setyawan-Curtarolo", "value": "sc"}, + {"label": "Latimer-Munro", "value": "lm"}, + {"label": "Hinuma et al.", "value": "hin"}, + ] - summary_dict = self._get_data_list_dict(bs, dos) - summary_table = get_data_list(summary_dict) + path_style = {"maxWidth": "200"} - return figure, zone_scene.to_json(), summary_table + return dos_options, path_options, path_style + + @app.callback( + Output(self.id("traces"), "data"), + Output(self.id("elements"), "data"), + Input(self.id(), "data"), + Input(self.id("path-convention"), "value"), + Input(self.id("dos-select"), "value"), + Input(self.id("label-select"), "value"), + ) + def bs_dos_data(data, dos_select, label_select): + # Obtain bands to plot over and generate traces for bs data: + energy_window = (-6.0, 10.0) + + traces = [] + + bsml, density_of_states = self._get_ph_bs_dos(data) + + if self.bandstructure_symm_line: + bs_traces = self.get_ph_bandstructure_traces( + bsml, freq_range=energy_window + ) + traces.append(bs_traces) + + if self.density_of_states: + dos_traces = self.get_ph_dos_traces( + density_of_states, freq_range=energy_window + ) + traces.append(dos_traces) + + # traces = [bs_traces, dos_traces, bs_data] + + # TODO: not tested if this is correct way to get element list + elements = list(map(str, density_of_states.get_element_dos())) + + return traces, elements @app.callback( Output(self.id("brillouin-zone"), "data"), @@ -887,78 +711,8 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): # TODO: figure out what to return (CSS?) to highlight BZ edge/point return - @app.callback( - Output(self.id("crystal-animation"), "data"), - Input(self.id("ph-bsdos-graph"), "clickData"), - Input(self.id("ph_bs"), "data"), - Input(self.id("supercell-controls-btn"), "n_clicks"), - Input(self.get_kwarg_id("magnitude"), "value"), - State(self.get_kwarg_id("scale-x"), "value"), - State(self.get_kwarg_id("scale-y"), "value"), - State(self.get_kwarg_id("scale-z"), "value"), - # prevent_initial_call=True - ) - def update_crystal_animation( - cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z - ): - # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields, - # ensuring updates occur only after the `supercell-controls-btn`` is clicked. - - if not bs: - raise PreventUpdate - - # Since `self.get_kwarg_id()` uses dash.dependencies.ALL, it returns a list of values. - # Although we could use `magnitude_fraction = magnitude_fraction[0]` to get the first value, - # this approach provides better clarity and readability. - kwargs = self.reconstruct_kwargs_from_state() - magnitude_fraction = kwargs.get("magnitude") - scale_x = kwargs.get("scale-x") - scale_y = kwargs.get("scale-y") - scale_z = kwargs.get("scale-z") - - if isinstance(bs, dict): - bs = PhononBandStructureSymmLine.from_dict(bs) - - struct = bs.structure - total_repeat_cell_cnt = 1 - # update structure if the controls got triggered - if sueprcell_update: - total_repeat_cell_cnt = scale_x * scale_y * scale_z - - # create supercell - trans = SupercellTransformation( - ((scale_x, 0, 0), (0, scale_y, 0), (0, 0, scale_z)) - ) - struct = trans.apply_transformation(struct) - - struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) - scene = struc_graph.get_scene( - draw_image_atoms=False, - bonded_sites_outside_unit_cell=False, - site_get_scene_kwargs={"retain_atom_idx": True}, - ) - json_data = scene.to_json() - - qpoint = 0 - band_num = 0 - - if cd and cd.get("points"): - pt = cd["points"][0] - qpoint, band_num = pt.get("customdata", [0, 0]) - - # magnitude - magnitude = ( - MAX_MAGNITUDE - MIN_MAGNITUDE - ) * magnitude_fraction + MIN_MAGNITUDE - - return PhononBandstructureAndDosComponent._get_eigendisplacement( - ph_bs=bs, - json_data=json_data, - band=band_num, - qpoint=qpoint, - total_repeat_cell_cnt=total_repeat_cell_cnt, - magnitude=magnitude, - ) + # TODO: figure out what to return (CSS?) to highlight BZ edge/point + return class PhononBandstructureAndDosPanelComponent(PanelComponent):