diff --git a/tlm_adjoint/torch.py b/tlm_adjoint/torch.py
index 7ba77139..6be0cd6c 100644
--- a/tlm_adjoint/torch.py
+++ b/tlm_adjoint/torch.py
@@ -30,11 +30,14 @@
 ]
 
 
-def to_torch_tensor(x, *args, **kwargs):
-    return torch.tensor(var_get_values(x), *args, **kwargs)
+def to_torch_tensor(x, *args, conjugate=False, **kwargs):
+    x_a = var_get_values(x)
+    if conjugate:
+        x_a = x_a.conjugate()
+    return torch.tensor(x_a, *args, **kwargs)
 
 
-def to_torch_tensors(X, *args, **kwargs):
+def to_torch_tensors(X, *args, conjugate=False, **kwargs):
     """Convert one or more variables to :class:`torch.Tensor` objects.
 
     Parameters
@@ -42,6 +45,8 @@ def to_torch_tensors(X, *args, **kwargs):
 
     X : variable or Sequence[variable, ...]
         Variables to be converted.
+    conjugate : bool
+        Whether to copy the complex conjugate.
     args, kwargs
         Passed to :func:`torch.tensor`.
 
@@ -52,15 +57,19 @@ def to_torch_tensors(X, *args, **kwargs):
         The converted variables.
     """
 
-    return tuple(to_torch_tensor(x, *args, **kwargs) for x in packed(X))
+    return tuple(to_torch_tensor(x, *args, conjugate=conjugate, **kwargs)
+                 for x in packed(X))
 
 
-def from_torch_tensor(x, x_t):
-    var_set_values(x, x_t.detach().numpy())
+def from_torch_tensor(x, x_t, *, conjugate=False):
+    x_a = x_t.detach().numpy()
+    if conjugate:
+        x_a = x_a.conjugate()
+    var_set_values(x, x_a)
     return x
 
 
-def from_torch_tensors(X, X_t):
+def from_torch_tensors(X, X_t, *, conjugate=False):
     """Copy data from PyTorch tensors into variables.
 
     Parameters
@@ -70,13 +79,15 @@ def from_torch_tensors(X, X_t):
         Output.
     X_t : Sequence[:class:`torch.Tensor`, ...]
         Input.
+    conjugate : bool
+        Whether to copy the complex conjugate.
     """
 
     X = packed(X)
     if len(X) != len(X_t):
         raise ValueError("Invalid length")
     for x, x_t in zip(X, X_t):
-        from_torch_tensor(x, x_t)
+        from_torch_tensor(x, x_t, conjugate=conjugate)
 
 
 @restore_manager
@@ -123,11 +134,12 @@ def backward(ctx, *adj_X_t):
                             clear_caches=clear_caches)
             J_id[0] = var_id(J)
 
-        from_torch_tensors(adj_X, adj_X_t)
+        from_torch_tensors(adj_X, adj_X_t, conjugate=True)
 
         set_manager(manager)
         dJ = compute_gradient(J, M)
-        return (None, None, None, None, None) + to_torch_tensors(dJ)
+        return ((None, None, None, None, None)
+                + to_torch_tensors(dJ, conjugate=True))
 
 
 def torch_wrapped(forward, space, *, manager=None, clear_caches=True):
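
A minimal round-trip sketch of the new conjugate keyword. This is hypothetical usage, not part of the change: it assumes the tlm_adjoint Float backend accepts complex values and that var_get_values is importable from the package root.

    # Hypothetical usage sketch -- assumes complex support in the Float backend
    import numpy as np

    from tlm_adjoint import Float, var_get_values
    from tlm_adjoint.torch import from_torch_tensors, to_torch_tensors

    # Forward copy with conjugation: the tensor holds conj(x)
    x = Float(1.0 + 2.0j)
    (x_t,) = to_torch_tensors(x, conjugate=True)
    assert np.allclose(x_t.numpy(), 1.0 - 2.0j)

    # Reverse copy with conjugation recovers the original value
    y = Float(0.0j)
    from_torch_tensors(y, (x_t,), conjugate=True)
    assert np.allclose(var_get_values(y), 1.0 + 2.0j)

The fixed conjugate=True in backward bridges the two gradient conventions: PyTorch's autograd propagates conjugate (Wirtinger) gradients for complex tensors, so the incoming adjoint seed is conjugated before being handed to tlm_adjoint, and the result of compute_gradient is conjugated again on the way back out.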