Skip to content

Commit

Permalink
mypy fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
esoteric-ephemera committed Nov 21, 2024
1 parent 3c44175 commit efbbe55
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions emmet-core/emmet/core/neb.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,12 @@ class NebTaskDoc(BaseModel, extra="allow"):
@model_validator(mode="after")
def set_barriers(self) -> Self:
"""Perform analysis on barrier if needed."""
if not self.forward_barrier or not self.reverse_barrier:
self.barrier_analysis = neb_barrier_spline_fit(self.energies)
if (
not self.forward_barrier
or not self.reverse_barrier
and self.energies is not None
):
self.barrier_analysis = neb_barrier_spline_fit(self.energies) # type: ignore[arg-type]
for k in ("forward", "reverse"):
setattr(self, f"{k}_barrier", self.barrier_analysis[f"{k}_barrier"])
return self
Expand All @@ -140,10 +144,10 @@ def num_images(self) -> int:
return len(self.image_directories)

@property
def energies(self) -> list[float]:
def energies(self) -> list[float] | None:
"""Return the endpoint (optional) and image energies."""
if self.endpoint_energies is not None:
return [
return [ # type: ignore[misc]
self.endpoint_energies[0],
*self.image_energies,
self.endpoint_energies[1],
Expand Down Expand Up @@ -345,7 +349,7 @@ def neb_barrier_spline_fit(
"energies": list(energies),
"frame_index": list(frame_idx := np.linspace(0.0, 1.0, len(energies))),
}
energies = np.array(energies)
energies = np.array(energies) # type: ignore[assignment]

spline_kwargs = spline_kwargs or {"bc_type": "clamped"}
spline_fit = CubicSpline(frame_idx, energies, **spline_kwargs)
Expand Down

0 comments on commit efbbe55

Please sign in to comment.