Skip to content

Commit

Permalink
Fixes to batched imlementation
Browse files Browse the repository at this point in the history
  • Loading branch information
JMorado committed Nov 25, 2024
1 parent f808d0d commit 800306d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions emle/models/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def forward(
charges_mm: Tensor,
xyz_qm: Tensor,
xyz_mm: Tensor,
qm_charge: Union[int, Tensor] = _torch.tensor(0, dtype=_torch.int64),
qm_charge: Union[int, Tensor] = 0,
) -> Tensor:
"""
Computes the static and induced EMLE energy components.
Expand Down Expand Up @@ -456,7 +456,7 @@ def forward(
qm_charge = _torch.full((batch_size,), qm_charge if qm_charge != 0 else self._qm_charge, dtype=_torch.int64, device=self._device)
elif isinstance(qm_charge, _torch.Tensor):
if qm_charge.ndim == 0:
qm_charge = qm_charge.repeat(batch_size)
qm_charge = qm_charge.repeat(batch_size).to(self._device)

# If there are no point charges, return zeros.
if xyz_mm.shape[1] == 0:
Expand All @@ -479,7 +479,7 @@ def forward(

# Compute the static energy.
if self._method == "mm":
q_core = self._q_core_mm
q_core = self._q_core_mm.expand(batch_size, -1)
q_val = _torch.zeros_like(
q_core, dtype=self._charges_mm.dtype, device=self._device
)
Expand Down

0 comments on commit 800306d

Please sign in to comment.