Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 21 additions & 27 deletions src/phelel/velph/cli/el_bands/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,9 @@
import click
import h5py
import numpy as np
from numpy.typing import NDArray

from phelel.velph.cli.utils import (
get_distances_along_BZ_path,
get_reclat_from_vaspout,
get_special_points,
)
from phelel.velph.utils.vasp import get_bands_data, get_reclat_from_vaspout


def plot_el_bandstructures(
Expand Down Expand Up @@ -48,14 +45,14 @@ def plot_el_bandstructures(
)
if "results" not in f_h5py_dos:
raise ValueError(f"No electronic DOS results found in {vaspout_filename_dos}.")
if "electron_dos_kpoints_opt" in f_h5py_dos["results"]:
if "electron_dos_kpoints_opt" in f_h5py_dos["results"]: # type: ignore
f_h5py_dos_results = f_h5py_dos["results/electron_dos_kpoints_opt"]
elif "electron_dos" in f_h5py_dos["results"]:
elif "electron_dos" in f_h5py_dos["results"]: # type: ignore
f_h5py_dos_results = f_h5py_dos["results/electron_dos"]
else:
raise ValueError("No electron DOS data found in vaspout.h5.")

efermi = f_h5py_dos_results["efermi"][()]
efermi: float = f_h5py_dos_results["efermi"][()] # type: ignore
emin = window[0]
emax = window[1]
_, axs = plt.subplots(1, 2, gridspec_kw={"width_ratios": [3, 1]})
Expand All @@ -64,8 +61,8 @@ def plot_el_bandstructures(
reclat = get_reclat_from_vaspout(f_h5py_bands)
distances, eigvals, points, labels_at_points = _get_bands_data(
reclat,
f_h5py_bands["results/electron_eigenvalues_kpoints_opt"],
f_h5py_bands["input/kpoints_opt"],
f_h5py_bands["results/electron_eigenvalues_kpoints_opt"], # type: ignore
f_h5py_bands["input/kpoints_opt"], # type: ignore
)

lines = ["-k", "--k"]
Expand All @@ -81,7 +78,9 @@ def plot_el_bandstructures(
ax0.set_ylabel("E[eV]", fontsize=14)
ymin, ymax = ax0.get_ylim()

dos, energies, xmax = _get_dos_data(f_h5py_dos_results, ymin, ymax)
dos: NDArray = f_h5py_dos_results["dos"][:] # type: ignore
energies: NDArray = f_h5py_dos_results["energies"][:] # type: ignore
xmax = _get_dos_data(dos, energies, ymin, ymax)

lines = ["-k", "--k"]
for i, dos_spin in enumerate(dos):
Expand All @@ -104,32 +103,27 @@ def plot_el_bandstructures(


def _get_bands_data(
reclat: np.ndarray, f_h5py_bands_results: h5py.Group, f_h5py_bands_input: h5py.Group
reclat: NDArray, f_h5py_bands_results: h5py.Group, f_h5py_bands_input: h5py.Group
):
eigvals = f_h5py_bands_results["eigenvalues"][:]
eigvals: NDArray = f_h5py_bands_results["eigenvalues"][:] # type: ignore

# k-points in reduced coordinates
kpoint_coords = f_h5py_bands_results["kpoint_coords"]
kpoint_coords: NDArray = f_h5py_bands_results["kpoint_coords"][:] # type: ignore
# Special point labels
labels = [
label.decode("utf-8") for label in f_h5py_bands_input["labels_kpoints"][:]
label.decode("utf-8")
for label in f_h5py_bands_input["labels_kpoints"][:] # type: ignore
]
nk_per_seg = f_h5py_bands_input["number_kpoints"][()]
nk_total = len(kpoint_coords)
k_cart = kpoint_coords @ reclat
n_segments = nk_total // nk_per_seg
assert n_segments * nk_per_seg == nk_total
distances = get_distances_along_BZ_path(nk_total, n_segments, nk_per_seg, k_cart)
points, labels_at_points = get_special_points(
labels, distances, n_segments, nk_per_seg, nk_total
nk_per_seg: int = f_h5py_bands_input["number_kpoints"][()] # type: ignore

distances, points, labels_at_points = get_bands_data(
kpoint_coords, reclat, nk_per_seg, labels
)

return distances, eigvals, points, labels_at_points


def _get_dos_data(f_h5py_dos_results: h5py.Group, ymin: float, ymax: float):
dos = f_h5py_dos_results["dos"][:]
energies = f_h5py_dos_results["energies"][:]
def _get_dos_data(dos: NDArray, energies: NDArray, ymin: float, ymax: float) -> float:
i_min = 0
i_max = len(energies)
for i, val in enumerate(energies):
Expand All @@ -154,4 +148,4 @@ def _get_dos_data(f_h5py_dos_results: h5py.Group, ymin: float, ymax: float):
dos = np.where(dos > 10000, 0, dos)
xmax = dos[:, i_min : i_max + 1].max() * 1.1

return dos, energies, xmax
return xmax
18 changes: 17 additions & 1 deletion src/phelel/velph/cli/init/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
VelphFilePaths,
VelphInitOptions,
VelphInitParams,
)
from phelel.velph.templates import default_template_dict
from phelel.velph.utils.structure import (
generate_standardized_cells,
get_primitive_cell,
get_reduced_cell,
get_symmetry_dataset,
)
from phelel.velph.templates import default_template_dict
from phelel.velph.utils.vasp import CutoffToFFTMesh, VaspIncar
from phelel.version import __version__

Expand Down Expand Up @@ -549,6 +551,20 @@ def _get_cells(
click.echo(f" Reference space-group-type: {spg_type.international_short}")

if symmetrize_cell:
if isinstance(sym_dataset, SpglibDataset):
click.echo(
"Crystal structure was standardized based on space-group-type "
f"{sym_dataset.international}."
)
elif isinstance(sym_dataset, SpglibMagneticDataset):
click.echo(
"Crystal structure was standardized based on magnetic-space-group-type "
f"UNI No.{sym_dataset.uni_number}."
)
else:
raise ValueError(
"sym_dataset must be SpglibDataset or SpglibMagneticDataset."
)
unitcell, _primitive, tmat = generate_standardized_cells(
sym_dataset, tolerance=tolerance
)
Expand Down
28 changes: 10 additions & 18 deletions src/phelel/velph/cli/ph_bands/plot.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Implementation of velph-ph_bands-plot."""

from __future__ import annotations

import pathlib

import click
import h5py
import numpy as np
from numpy.typing import NDArray

from phelel.velph.cli.utils import (
get_distances_along_BZ_path,
get_reclat_from_vaspout,
get_special_points,
)
from phelel.velph.utils.vasp import get_bands_data, get_reclat_from_vaspout


def plot_ph_bandstructures(
Expand All @@ -37,24 +36,17 @@ def plot_ph_bandstructures(
import matplotlib.pyplot as plt

f = h5py.File(vaspout_filename)
eigvals = f["results"]["phonons"]["eigenvalues"][:] # phonon eigenvalues
eigvals: NDArray = f["results/phonons/eigenvalues"][:] # type: ignore
if use_ordinary_frequency:
eigvals /= 2 * np.pi
omega_max = 1.1 * eigvals.max()

reclat = get_reclat_from_vaspout(f)
labels = [
label.decode("utf-8") for label in f["input"]["qpoints"]["labels_kpoints"][:]
]
nk_per_seg = f["input"]["qpoints"]["number_kpoints"][()]
kpoint_coords = f["results"]["phonons"]["kpoint_coords"]
nk_total = len(kpoint_coords)
k_cart = kpoint_coords @ reclat
n_segments = nk_total // nk_per_seg
assert n_segments * nk_per_seg == nk_total
distances = get_distances_along_BZ_path(nk_total, n_segments, nk_per_seg, k_cart)
points, labels_at_points = get_special_points(
labels, distances, n_segments, nk_per_seg, nk_total
labels = [label.decode("utf-8") for label in f["input/qpoints/labels_kpoints"][:]] # type: ignore
nk_per_seg: int = f["input/qpoints/number_kpoints"][()] # type: ignore
kpoint_coords: NDArray = f["results/phonons/kpoint_coords"][:] # type: ignore
distances, points, labels_at_points = get_bands_data(
kpoint_coords, reclat, nk_per_seg, labels
)

_, ax = plt.subplots()
Expand Down
Loading
Loading