diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 7e1c8bb..1a3b4bc 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -408,7 +408,7 @@ def forward( charges_mm: Tensor, xyz_qm: Tensor, xyz_mm: Tensor, - qm_charge: Union[int, Tensor] = _torch.tensor(0, dtype=_torch.int64), + qm_charge: Union[int, Tensor] = 0, ) -> Tensor: """ Computes the static and induced EMLE energy components. @@ -456,7 +456,7 @@ def forward( qm_charge = _torch.full((batch_size,), qm_charge if qm_charge != 0 else self._qm_charge, dtype=_torch.int64, device=self._device) elif isinstance(qm_charge, _torch.Tensor): if qm_charge.ndim == 0: - qm_charge = qm_charge.repeat(batch_size) + qm_charge = qm_charge.repeat(batch_size).to(self._device) # If there are no point charges, return zeros. if xyz_mm.shape[1] == 0: @@ -479,7 +479,7 @@ def forward( # Compute the static energy. if self._method == "mm": - q_core = self._q_core_mm + q_core = self._q_core_mm.expand(batch_size, -1) q_val = _torch.zeros_like( q_core, dtype=self._charges_mm.dtype, device=self._device )