Skip to content

Commit

Permalink
Add var_to_petsc and var_from_petsc
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Jul 17, 2024
1 parent dbde2ec commit ed420a5
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 40 deletions.
12 changes: 11 additions & 1 deletion tlm_adjoint/fenics/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""

from .backend import (
backend_Constant, backend_Function, backend_ScalarType, backend_Vector)
as_backend_type, backend_Constant, backend_Function, backend_ScalarType,
backend_Vector)
from ..interface import (
SpaceInterface, VariableInterface, check_space_types,
register_subtract_adjoint_derivative_action,
Expand Down Expand Up @@ -430,6 +431,15 @@ def _set_values(self, values):
self.vector().set_local(values)
self.vector().apply("insert")

def _to_petsc(self, vec):
self_v = as_backend_type(self.vector()).vec()
self_v.copy(result=vec)

def _from_petsc(self, vec):
self_v = as_backend_type(self.vector()).vec()
vec.copy(result=self_v)
self.vector().apply("insert")

@check_vector
def _new(self, *, name=None, static=False, cache=None,
rel_space_type="primal"):
Expand Down
8 changes: 8 additions & 0 deletions tlm_adjoint/firedrake/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,14 @@ def _set_values(self, values):
with self.dat.vec_wo as x_v:
x_v.setArray(values)

def _to_petsc(self, vec):
with self.dat.vec_ro as self_v:
self_v.copy(result=vec)

def _from_petsc(self, vec):
with self.dat.vec_wo as self_v:
vec.copy(result=self_v)

def _is_replacement(self):
return False

Expand Down
49 changes: 47 additions & 2 deletions tlm_adjoint/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def call(obj, /, *args, **kwargs):
"var_comm",
"var_copy",
"var_dtype",
"var_from_petsc",
"var_get_values",
"var_global_size",
"var_id",
Expand All @@ -128,6 +129,7 @@ def call(obj, /, *args, **kwargs):
"var_space",
"var_space_type",
"var_state",
"var_to_petsc",
"var_update_caches",
"var_update_state",
"var_zero",
Expand Down Expand Up @@ -793,8 +795,9 @@ class VariableInterface:
"_state", "_update_state", "_is_static", "_is_cached", "_caches",
"_zero", "_assign", "_axpy", "_inner", "_linf_norm",
"_local_size", "_global_size", "_local_indices", "_get_values",
"_set_values", "_new", "_copy", "_replacement", "_is_replacement",
"_is_scalar", "_scalar_value", "_is_alias")
"_set_values", "_to_petsc", "_from_petsc", "_new", "_copy",
"_replacement", "_is_replacement", "_is_scalar", "_scalar_value",
"_is_alias")

def __init__(self):
raise RuntimeError("Cannot instantiate VariableInterface object")
Expand Down Expand Up @@ -862,6 +865,14 @@ def _get_values(self):
def _set_values(self, values):
raise NotImplementedError("Method not overridden")

def _to_petsc(self, vec):
values = var_get_values(self)
vec.setArray(values)

def _from_petsc(self, vec):
values = vec.getArray(True)
var_set_values(self, values)

def _new(self, *, name=None, static=False, cache=None,
rel_space_type="primal"):
space_type = var_space_type(self, rel_space_type=rel_space_type)
Expand Down Expand Up @@ -1327,6 +1338,40 @@ def var_get_values(x):
return values


@manager_disabled()
def var_to_petsc(x, vec):
"""Copy values from a variable into a :class:`petsc4py.PETSc.Vec`.
Does not update the :class:`petsc4py.PETSc.Vec` ghost.
Parameters
----------
x : variable
The input variable.
vec : :class:`petsc4py.PETSc.Vec`
The output :class:`petsc4py.PETSc.Vec`. The ghost is not updated.
"""

x._tlm_adjoint__var_interface_to_petsc(vec)


@manager_disabled()
def var_from_petsc(x, vec):
"""Copy values from a :class:`petsc4py.PETSc.Vec` into a variable.
Parameters
----------
x : variable
The output variable.
vec : :class:`petsc4py.PETSc.Vec`
The input :class:`petsc4py.PETSc.Vec`.
"""

x._tlm_adjoint__var_interface_from_petsc(vec)


@manager_disabled()
def var_set_values(x, values):
"""Set the process local degrees of freedom vector associated with a
Expand Down
67 changes: 30 additions & 37 deletions tlm_adjoint/petsc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from .interface import (
space_comm, space_global_size, space_local_size, var_dtype, var_get_values,
var_local_size, var_set_values)
space_comm, space_global_size, space_local_size, var_from_petsc,
var_to_petsc)

from contextlib import contextmanager
try:
import mpi4py.MPI as MPI
except ModuleNotFoundError:
MPI = None
import numpy as np
try:
import petsc4py.PETSc as PETSc
Expand Down Expand Up @@ -88,20 +92,27 @@ def __init__(self, spaces, *, dtype=None, comm=None):
dtype = PETSc.ScalarType
dtype = np.dtype(dtype).type

indices = []
n = 0
N = 0
n = sum(map(space_local_size, spaces))
N = sum(map(space_global_size, spaces))

isets = []
i0 = comm.scan(n, op=MPI.SUM) - n
for space in spaces:
indices.append((n, n + space_local_size(space)))
n += space_local_size(space)
N += space_global_size(space)
i1 = i0 + space_local_size(space)
iset = PETSc.IS().createGeneral(
np.arange(i0, i1, dtype=PETSc.IntType),
comm=comm)
isets.append(iset)
i0 = i1

self._comm = comm
self._dtype = dtype
self._indices = tuple(indices)
self._isets = tuple(isets)
self._n = n
self._N = N

attach_destroy_finalizer(self, *isets)

@property
def comm(self):
return self._comm
Expand All @@ -110,10 +121,6 @@ def comm(self):
def dtype(self):
return self._dtype

@property
def indices(self):
return self._indices

@property
def local_size(self):
return self._n
Expand All @@ -123,34 +130,20 @@ def global_size(self):
return self._N

def from_petsc(self, y, X):
y_a = y.getArray(True)

if y_a.shape != (self.local_size,):
raise ValueError("Invalid shape")
if len(X) != len(self.indices):
if len(X) != len(self._isets):
raise ValueError("Invalid length")
for (i0, i1), x in zip(self.indices, X):
if not np.can_cast(y_a.dtype, var_dtype(x)):
raise ValueError("Invalid dtype")
if var_local_size(x) != i1 - i0:
raise ValueError("Invalid length")

for (i0, i1), x in zip(self.indices, X):
var_set_values(x, y_a[i0:i1])
for i, x in enumerate(X):
y_sub = y.getSubVector(self._isets[i])
var_from_petsc(x, y_sub)
y.restoreSubVector(self._isets[i], y_sub)

def to_petsc(self, x, Y):
if len(Y) != len(self.indices):
if len(Y) != len(self._isets):
raise ValueError("Invalid length")
for (i0, i1), y in zip(self.indices, Y):
if not np.can_cast(var_dtype(y), self.dtype):
raise ValueError("Invalid dtype")
if var_local_size(y) != i1 - i0:
raise ValueError("Invalid length")

x_a = np.zeros(self.local_size, dtype=self.dtype)
for (i0, i1), y in zip(self.indices, Y):
x_a[i0:i1] = var_get_values(y)
x.setArray(x_a)
for i, y in enumerate(Y):
x_sub = x.getSubVector(self._isets[i])
var_to_petsc(y, x_sub)
x.restoreSubVector(self._isets[i], x_sub)

def _new_petsc(self):
vec = PETSc.Vec().create(comm=self.comm)
Expand Down

0 comments on commit ed420a5

Please sign in to comment.