diff --git a/aiidalab_sssp/inspect/subwidgets/bands.py b/aiidalab_sssp/inspect/subwidgets/bands.py index 1df8e0b..91731ca 100644 --- a/aiidalab_sssp/inspect/subwidgets/bands.py +++ b/aiidalab_sssp/inspect/subwidgets/bands.py @@ -20,6 +20,8 @@ _RY_TO_EV = 13.6056980659 _FERMI_SHIFT = 10.0 # eV in protocol FIXME also change title of plot Tab widget +_SMEARING_WIDTH = _DEGAUSS * _RY_TO_EV + def _bandview(json_path): """ @@ -85,27 +87,41 @@ def _on_pseudo_select(self, _): path = pseudo1["accuracy"]["bands"]["band_structure"] json_path = Path.joinpath(SSSP_DB, path) - band = _bandview(json_path) - if band: - bands.append(band) + bandsdata_a = self.bands_align_to_fermi(_bandview(json_path)) + bands.append(bandsdata_a) if pseudo2: path = pseudo2["accuracy"]["bands"]["band_structure"] json_path = Path.joinpath(SSSP_DB, path) - band = _bandview(json_path) - if band: - bands.append(band) + bandsdata_b = self.bands_align_to_fermi(_bandview(json_path)) + bands.append(bandsdata_b) _band_structure_preview = BandsPlotWidget( bands=bands, - energy_range={"ymin": -30.0, "ymax": 11.0}, + energy_range={"ymin": -10.0, "ymax": 15.0}, + fermi_energy=0.0, # since we have aligned to fermi level ) with self.band_structure: clear_output(wait=True) display(_band_structure_preview) + def bands_align_to_fermi(self, bandsdata): + """ + align the band structure to fermi level + """ + fermi_energy = bandsdata["fermi_level"] + + for path in bandsdata["paths"]: + values = [[y - fermi_energy for y in ys] for ys in path["values"]] + path["values"] = values + + # After align to fermi level, we need to update the fermi level to 0.0 + bandsdata["fermi_level"] = 0.0 + + return bandsdata + class BandChessboard(ipw.VBox): """Band distance compare in chess board""" @@ -166,7 +182,7 @@ def _render_plot(ax_v, ax_c, arr_v, arr_c, labels): for idx, (ax, arr, title) in enumerate( [(ax_v, arr_v, r"$\eta_v$"), (ax_c, arr_c, r"$\eta_{10}$")] ): - ax.imshow(arr) + ax.imshow(arr, vmin=0, vmax=50, cmap="viridis") # Show all ticks and label them with the respective list entries # We want to show all ticks... @@ -234,7 +250,7 @@ def _bands_distance(self, pseudos): distance = get_bands_distance( bandsdata_a=bandsdata1, bandsdata_b=bandsdata2, - smearing=_DEGAUSS * _RY_TO_EV, + smearing=_SMEARING_WIDTH, fermi_shift=fermi_shift, do_smearing=do_smearing, spin=spin,