Skip to content

Commit

Permalink
Work around lack of inheritance in TorchScript.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Aug 12, 2024
1 parent 1d9f02f commit 6d19c27
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/sire/qm/_emle.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def emle(
)

try:
import torch as _torch
from emle.models import EMLE as _EMLE

has_model = True
Expand All @@ -181,7 +182,17 @@ def emle(
raise ValueError("Unable to select 'qm_atoms' from 'mols'")

if has_model:
if not isinstance(calculator, (_EMLECalculator, _EMLE)):
# EMLECalculator.
if isinstance(calculator, _EMLECalculator):
pass
# EMLE model. Note that TorchScript doesn't support inheritance, so
# we need to check whether this is a torch.nn.Module and whether it
# has the "_is_emle" attribute, which is added to all EMLE models.
elif isinstance(calculator, _torch.nn.Module) and hasattr(
calculator, "_is_emle"
):
pass
else:
raise TypeError(
"'calculator' must be a of type 'emle.calculator.EMLECalculator' or 'emle.models.EMLE'"
)
Expand Down

0 comments on commit 6d19c27

Please sign in to comment.