Skip to content

Commit

Permalink
Bug fix in fermisurface2d plot. skimage contour output was mesh index…
Browse files Browse the repository at this point in the history
… points instead of kmesh grid. Needed an interpolation to map back to kmesh from the mesh index
  • Loading branch information
lllangWV committed Aug 27, 2024
1 parent 94414e4 commit 8140afa
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions pyprocar/core/fermisurface.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def plot(self,mode:str, interpolation=500):
xmin : xmax : interpolation * 1j, ymin : ymax : interpolation * 1j
]

unique_x = xnew[:,0]
unique_y = ynew[0,:]

# interpolation
n_spins = self.bands.shape[2]
for i_spin in range(n_spins):
Expand Down Expand Up @@ -209,7 +212,6 @@ def plot(self,mode:str, interpolation=500):
bnew = np.array(bnew)

# Generates colors per band

n_bands = bands.shape[0]
cmap = cm.get_cmap(self.config['cmap']['value'])
if i_spin == 1:
Expand All @@ -219,10 +221,20 @@ def plot(self,mode:str, interpolation=500):
solid_color_surface = np.arange(n_bands) / n_bands + factor
band_colors = np.array([cmap(norm(x)) for x in solid_color_surface[:]]).reshape(-1, 4)
plots = []

for i_band,band_energies in enumerate(bnew):
contours = measure.find_contours(band_energies, self.energy)
for i_contour,contour in enumerate(contours):
points = np.array([contour[:, 0], contour[:, 1]]).T.reshape(-1, 1, 2)
# measure.find contours returns a list of coordinates indcies of the mesh.
# However, due to the algorithm they take values that are in between mesh points.
# We need to interpolate the values to the original kmesh
x_vals=contour[:,0]
y_vals=contour[:,1]
x_interp=np.interp(x_vals, np.arange(0,unique_x.shape[0]), unique_x)
y_interp=np.interp(y_vals, np.arange(0,unique_y.shape[0]), unique_y)
points=np.array([[x_interp,y_interp]])
points=np.moveaxis(points, -1, 0)

segments = np.concatenate([points[:-1], points[1:]], axis=1)
if mode=='plain':
lc = LineCollection(segments,
Expand All @@ -243,7 +255,7 @@ def plot(self,mode:str, interpolation=500):
label=f'Band {band_labels[i_band]}'
lc.set_label(label)
if mode=='parametric':
c = griddata((x, y), spd[i_band,:], (contour[:, 0], contour[:, 1]), method="nearest")
c = griddata((x, y), spd[i_band,:], (x_interp, y_interp), method="cubic")
lc = LineCollection(segments, cmap=plt.get_cmap(self.config['cmap']['value']), norm=norm)
lc.set_array(c)

Expand Down Expand Up @@ -297,6 +309,7 @@ def spin_texture(self, sx, sy, sz,
# selecting components of K-points
x, y = self.kpoints[:, 0], self.kpoints[:, 1]


if self.band_indices is None:
bands = self.bands[:,self.useful_bands_by_spins[0],0].transpose()
band_labels = np.unique(self.useful_bands_by_spins[0])
Expand Down Expand Up @@ -367,7 +380,7 @@ def spin_texture(self, sx, sy, sz,
if self.config['arrow_size']['value'] is not None:
# This is so the density scales the way you think. increasing number means increasing density.
# The number in the numerator is so it scales reasonable with 0-20
scale = 10/self.config['arrow_size']['value']
scale = 10/self.config['arrow_size']['value']
scale_units = "xy"
angles="xy"
else:
Expand Down

0 comments on commit 8140afa

Please sign in to comment.