Skip to content

Commit

Permalink
Merge pull request #63 from SimonBoothroyd/dipole_shape
Browse files Browse the repository at this point in the history
Dipole reshape
  • Loading branch information
jthorton authored May 30, 2023
2 parents ba61333 + 683980f commit 6c6cf2e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nagl/training/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def evaluate_loss(
prediction: typing.Dict[str, torch.Tensor],
) -> torch.Tensor:
metric_func = get_metric(self.metric)
target_dipole = labels[self.dipole_column].squeeze()
# reshape as it can be flat
target_dipole = torch.reshape(labels[self.dipole_column], (-1, 3))
n_atoms_per_molecule = (
(molecules.n_atoms,)
if isinstance(molecules, DGLMolecule)
Expand Down

0 comments on commit 6c6cf2e

Please sign in to comment.