diff --git a/pyprocar/plotter/ebs_plot.py b/pyprocar/plotter/ebs_plot.py index 4ffcd1e1..79a18dbe 100644 --- a/pyprocar/plotter/ebs_plot.py +++ b/pyprocar/plotter/ebs_plot.py @@ -5,10 +5,11 @@ import os import yaml +import json from typing import List import numpy as np - +import pandas as pd import matplotlib.pyplot as plt from matplotlib.collections import LineCollection import matplotlib as mpl @@ -52,6 +53,9 @@ def __init__(self, self.kpath = kpath self.spins = spins self.kdirect=kdirect + + self.values_dict={} + if self.spins is None: self.spins = range(self.ebs.nspins) self.nspins = len(self.spins) @@ -128,7 +132,7 @@ def plot_bands(self): None. """ - + values_dict={} for ispin in self.spins: if len(self.spins)==1: @@ -147,12 +151,31 @@ def plot_bands(self): ) self.handles.append(handle) + + + + band_name=f'band-{iband}_spinChannel-{str(ispin)}' + values_dict[f'bands_{band_name}']=self.ebs.bands[:, iband, ispin] + + values_dict['kpath_values']=self.x + tick_names=[] + for i,x in enumerate(self.x): + tick_name='' + for i_tick, tick_position in enumerate(self.kpath.tick_positions): + if i == tick_position: + tick_name=self.kpath.tick_names[i_tick] + tick_names.append(tick_name) + values_dict['kpath_tick_names']=tick_names + self.values_dict=values_dict + + def plot_scatter(self, width_mask:np.ndarray=None, color_mask:np.ndarray=None, spins:List[int]=None, width_weights:np.ndarray=None, color_weights:np.ndarray=None, + labels=None, ): """A method to plot a scatter plot @@ -169,6 +192,8 @@ def plot_scatter(self, color_weights : np.ndarray, optional The color weights at each point, by default None """ + values_dict={} + if spins is None: spins = range(self.ebs.nspins) if self.ebs.is_non_collinear: @@ -214,8 +239,7 @@ def plot_scatter(self, self.x, mbands[:, iband, ispin], c=color, - s=width_weights[:, iband, ispin].round( - 2)*markersize[ispin], + s=width_weights[:, iband, ispin]*markersize[ispin], # edgecolors="none", linewidths=self.config.linewidth[ispin], cmap=self.config.cmap, @@ -228,8 +252,8 @@ def plot_scatter(self, sc = self.ax.scatter( self.x, mbands[:, iband, ispin], - c=color_weights[:, iband, ispin].round(2), - s=width_weights[:, iband, ispin].round(2)*markersize[ispin], + c=color_weights[:, iband, ispin], + s=width_weights[:, iband, ispin]*markersize[ispin], # edgecolors="none", linewidths=self.config.linewidth[ispin], cmap=self.config.cmap, @@ -238,9 +262,31 @@ def plot_scatter(self, marker=self.config.marker[ispin], alpha=self.config.opacity[ispin], ) + + band_name=f'band-{iband}_spinChannel-{str(ispin)}' + values_dict[f'bands__{band_name}']=self.ebs.bands[:, iband, ispin] + projection_name=labels[0] + if color_weights is not None: + values_dict[f'projections__{projection_name}__{band_name}']=color_weights[:, iband, ispin] + + + if self.config.plot_color_bar and color_weights is not None: self.cb = self.fig.colorbar(sc, ax=self.ax) + + values_dict['kpath_values']=self.x + tick_names=[] + for i,x in enumerate(self.x): + tick_name='' + for i_tick, tick_position in enumerate(self.kpath.tick_positions): + if i == tick_position: + tick_name=self.kpath.tick_names[i_tick] + tick_names.append(tick_name) + values_dict['kpath_tick_names']=tick_names + + self.values_dict=values_dict + def plot_parameteric( self, spins:List[int]=None, @@ -248,7 +294,8 @@ def plot_parameteric( color_mask:np.ndarray=None, width_weights:np.ndarray=None, color_weights:np.ndarray=None, - elimit:List[float]=None + elimit:List[float]=None, + labels=None ): """A method to plot a scatter plot @@ -267,6 +314,7 @@ def plot_parameteric( elimit : List[float], optional Energy range to plot. Only useful if the band index is written """ + values_dict={} # if there is only a single k-point the method for atomic # levels will be called to fake another kpoint and then @@ -338,6 +386,12 @@ def plot_parameteric( width_weights[:, iband, ispin]*linewidth[ispin]) lc.set_linestyle(self.config.linestyle[ispin]) handle = self.ax.add_collection(lc) + + band_name=f'band-{iband}_spinChannel-{str(ispin)}' + projection_name=labels[0] + values_dict[F'bands__{band_name}']=self.ebs.bands[:, iband, ispin] + if color_weights is not None: + values_dict[F'projections__{projection_name}__{band_name}']=color_weights[:, iband, ispin] # if color_weights is not None: # handle.set_color(color_map[iweight][:-1].lower()) handle.set_linewidth(linewidth) @@ -346,9 +400,21 @@ def plot_parameteric( if self.config.plot_color_bar and color_weights is not None: self.cb = self.fig.colorbar(lc, ax=self.ax) + values_dict['kpath_values']=self.x + tick_names=[] + for i,x in enumerate(self.x): + tick_name='' + for i_tick, tick_position in enumerate(self.kpath.tick_positions): + if i == tick_position: + tick_name=self.kpath.tick_names[i_tick] + tick_names.append(tick_name) + values_dict['kpath_tick_names']=tick_names + self.values_dict=values_dict + def plot_parameteric_overlay(self, spins:List[int]=None, weights:np.ndarray=None, + labels:str=None, ): """A method to plot the parametric overlay @@ -359,6 +425,7 @@ def plot_parameteric_overlay(self, weights : np.ndarray, optional The weights of each point, by default None """ + values_dict={} linewidth = [l*7 for l in self.config.linewidth] if type(self.config.cmap) is str: @@ -395,6 +462,14 @@ def plot_parameteric_overlay(self, lc.set_array(weight[:, iband, ispin]) lc.set_linewidth(weight[:, iband, ispin]*linewidth[ispin]) handle = self.ax.add_collection(lc) + + + band_name=f'band-{iband}_spinChannel-{str(ispin)}' + projection_name=labels[iweight] + values_dict[f'bands__{band_name}']=self.ebs.bands[:, iband, ispin] + if weights is not None: + values_dict[f'projections__{projection_name}__{band_name}']=weight[:, iband, ispin] + handle.set_color(color_map[iweight][:-1].lower()) handle.set_linewidth(linewidth) self.handles.append(handle) @@ -402,13 +477,25 @@ def plot_parameteric_overlay(self, if self.config.plot_color_bar: self.cb = self.fig.colorbar(lc, ax=self.ax) + values_dict['kpath_values']=self.x + tick_names=[] + for i,x in enumerate(self.x): + tick_name='' + for i_tick, tick_position in enumerate(self.kpath.tick_positions): + if i == tick_position: + tick_name=self.kpath.tick_names[i_tick] + tick_names.append(tick_name) + values_dict['kpath_tick_names']=tick_names + self.values_dict=values_dict + def plot_atomic_levels(self, spins:List[int]=None, width_mask:np.ndarray=None, color_mask:np.ndarray=None, width_weights:np.ndarray=None, color_weights:np.ndarray=None, - elimit:List[float]=None + elimit:List[float]=None, + labels=None ): """A method to plot a scatter plot @@ -500,7 +587,8 @@ def plot_atomic_levels(self, width_weights=width_weights, color_mask=color_mask, width_mask=width_mask, - spins=spins) + spins=spins, + labels=labels) def set_xticks(self, tick_positions:List[int]=None, @@ -726,4 +814,61 @@ def save(self, filename:str='bands.pdf'): plt.savefig(filename, dpi=self.config.dpi, bbox_inches="tight") plt.clf() + def export_data(self,filename): + """ + This method will export the data to a csv file + + Parameters + ---------- + filename : str + The file name to export the data to + + Returns + ------- + None + None + """ + possible_file_types=['csv','txt','json','dat'] + file_type=filename.split('.')[-1] + if file_type not in possible_file_types: + raise ValueError(f"The file type must be {possible_file_types}") + if self.values_dict is None: + raise ValueError("The data has not been plotted yet") + + column_names=list(self.values_dict.keys()) + sorted_column_names=[None]*len(column_names) + index=0 + for column_name in column_names: + if 'kpath_values' == column_name: + sorted_column_names[index]=column_name + index+=1 + if 'kpath_tick_names' == column_name: + sorted_column_names[index]=column_name + index+=1 + for ispin in range(2): + for column_name in column_names: + + if 'spinChannel-0' in column_name.split('_')[-1] and ispin==0: + sorted_column_names[index]=column_name + index+=1 + if 'spinChannel-1' in column_name.split('_')[-1] and ispin==1: + sorted_column_names[index]=column_name + index+=1 + + column_names.sort() + if file_type=='csv': + df=pd.DataFrame(self.values_dict) + df.to_csv(filename, columns=sorted_column_names, index=False) + elif file_type=='txt': + df=pd.DataFrame(self.values_dict) + df.to_csv(filename, columns=sorted_column_names, sep='\t', index=False) + elif file_type=='json': + with open(filename, 'w') as outfile: + for key,value in self.values_dict.items(): + self.values_dict[key]=value.tolist() + json.dump(self.values_dict, outfile) + elif file_type=='dat': + df=pd.DataFrame(self.values_dict) + df.to_csv(filename, columns=sorted_column_names, sep=' ', index=False) + diff --git a/pyprocar/scripts/scriptBandsplot.py b/pyprocar/scripts/scriptBandsplot.py index 29a13913..17de7bf1 100644 --- a/pyprocar/scripts/scriptBandsplot.py +++ b/pyprocar/scripts/scriptBandsplot.py @@ -36,6 +36,8 @@ def bandsplot( show:bool=True, savefig:str=None, print_plot_opts:bool=False, + export_data_file:str=None, + export_append_mode:bool=True, **kwargs ): """A function to plot the band structutre @@ -79,6 +81,12 @@ def bandsplot( Boolean if to show the plot, by default True savefig : str, optional String to save the plot, by default None + export_data_file : str, optional + The file name to export the data to. If not provided the + data will not be exported. + export_append_mode : bool, optional + Boolean to append the mode to the file name. If not provided the + data will be overwritten. print_plot_opts: bool, optional Boolean to print the plotting options """ @@ -131,7 +139,7 @@ def bandsplot( ebs_plot = EBSPlot(ebs, kpath, ax, spins, kdirect=kdirect ,config=config) - + projection_labels=[] labels = [] if mode == "plain": ebs_plot.plot_bands() @@ -162,11 +170,13 @@ def bandsplot( elif mode in ["overlay", "overlay_species", "overlay_orbitals"]: weights = [] - if mode == "overlay_species": for ispc in structure.species: labels.append(ispc) atoms = np.where(structure.atoms == ispc)[0] + + projection_label=f'atom-{ispc}_orbitals-'+",".join(str(x) for x in orbitals) + projection_labels.append(projection_label) w = ebs_plot.ebs.ebs_sum( atoms=atoms, principal_q_numbers=[-1], @@ -180,6 +190,13 @@ def bandsplot( continue orbitals = orbital_names[orb] labels.append(orb) + + atom_label='' + if atoms: + atom_labels=",".join(str(x) for x in atoms) + atom_label=f'atom-{atom_labels}_' + projection_label=f'{atom_label}orbitals-{orb}' + projection_labels.append(projection_label) w = ebs_plot.ebs.ebs_sum( atoms=atoms, principal_q_numbers=[-1], @@ -204,6 +221,11 @@ def bandsplot( else: orbitals = it[ispc] labels.append(ispc + "-" + "_".join(str(x) for x in it[ispc])) + + atom_labels=",".join(str(x) for x in atoms) + orbital_labels=",".join(str(x) for x in orbitals) + projection_label=f'atoms-{atom_labels}_orbitals-{orbital_labels}' + projection_labels.append(projection_label) w = ebs_plot.ebs.ebs_sum( atoms=atoms, principal_q_numbers=[-1], @@ -211,7 +233,7 @@ def bandsplot( spins=spins, ) weights.append(w) - ebs_plot.plot_parameteric_overlay(spins=spins,weights=weights) + ebs_plot.plot_parameteric_overlay(spins=spins,weights=weights,labels=projection_labels) else: if atoms is not None and isinstance(atoms[0], str): atoms_str = atoms @@ -228,6 +250,19 @@ def bandsplot( for iorb in orbital_str: orbitals = np.append(orbitals, orbital_names[iorb]).astype(np.int) + projection_labels=[] + projection_label='' + atoms_labels='' + if atoms: + atoms_labels=",".join(str(x) for x in atoms) + projection_label+=f'atoms-{atoms_labels}' + orbital_labels='' + if orbitals: + orbital_labels=",".join(str(x) for x in orbitals) + if len(projection_label)!=0: + projection_label+='_' + projection_label+=f'orbitals-{orbital_labels}' + projection_labels.append(projection_label) weights = ebs_plot.ebs.ebs_sum(atoms=atoms, principal_q_numbers=[-1], orbitals=orbitals, spins=spins) if config.weighted_color: @@ -246,7 +281,8 @@ def bandsplot( width_weights=width_weights, color_mask=color_mask, width_mask=width_mask, - spins=spins + spins=spins, + labels=projection_labels ) ebs_plot.set_colorbar_title() elif mode == "scatter": @@ -255,7 +291,8 @@ def bandsplot( width_weights=width_weights, color_mask=color_mask, width_mask=width_mask, - spins=spins + spins=spins, + labels=projection_labels ) ebs_plot.set_colorbar_title() elif mode == "atomic": @@ -269,7 +306,9 @@ def bandsplot( color_mask=color_mask, width_mask=width_mask, spins=spins, - elimit=elimit) + elimit=elimit, + labels=projection_labels + ) ebs_plot.set_xlabel(label='') ebs_plot.set_colorbar_title() @@ -295,5 +334,13 @@ def bandsplot( ebs_plot.save(savefig) if show: ebs_plot.show() + + if export_data_file is not None: + if export_append_mode: + file_basename,file_type=export_data_file.split('.') + filename=f"{file_basename}_{mode}.{file_type}" + else: + filename=export_data_file + ebs_plot.export_data(filename) return ebs_plot.fig, ebs_plot.ax