OffloadPC (CUDA GPU) #3784

Draft: wants to merge 9 commits into master
57 changes: 57 additions & 0 deletions demos/helmholtz/helmholtz.txt
Reviewer (Contributor): Delete this file as discussed.

@@ -0,0 +1,57 @@
Main Stage 366614
Main Stage;firedrake 44369
Main Stage;firedrake;firedrake.solving.solve 86
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve 196
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve 140
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval 736
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute 212
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral 112
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral;pyop2.global_kernel.GlobalKernel.compile 415552
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;firedrake.tsfc_interface.compile_form 42597
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval 866
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute 149
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral 136
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral;pyop2.global_kernel.GlobalKernel.compile 407506
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__ 1771
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form 56423
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form;firedrake.formmanipulation.split_form 1907
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.solving_utils._SNESContext.__init__ 618
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__ 145
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.ufl_expr.action 4387
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__ 332
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__;firedrake.ufl_expr.adjoint 2798
Main Stage;firedrake;firedrake.function.Function.interpolate 342
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble 5644
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 29
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 298
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 204
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 682292
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 40658
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write 2473
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate 303
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble 1080
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 23
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 328
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 165
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 663410
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 55147
Main Stage;firedrake;firedrake.__init__ 495196
Main Stage;firedrake;firedrake.assemble.assemble 949
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute 310
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral 95
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral;pyop2.global_kernel.GlobalKernel.compile 355507
Main Stage;firedrake;firedrake.assemble.assemble;firedrake.tsfc_interface.compile_form 20219
Main Stage;firedrake;CreateFunctionSpace 919
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace 79
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__ 165
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data 13
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__ 825
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;FunctionSpaceData: CreateElement 1274
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;firedrake.mesh.MeshTopology._facets 789
Main Stage;firedrake;CreateFunctionSpace;CreateMesh 147
Main Stage;firedrake;CreateFunctionSpace;CreateMesh;Mesh: numbering 376
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh 12
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh 11
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh 834
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;CreateMesh 676
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;DMPlexInterp 382
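For context, this file is collapsed call-stack profile output: each line is a semicolon-separated stack followed by a numeric cost (apparently time spent in that frame), the format consumed by flame-graph tools. A minimal sketch of summarising such a file by leaf frame (the helper name is illustrative, not part of the PR):

from collections import Counter

def leaf_totals(path):
    """Sum each line's trailing cost into its leaf (last) stack frame."""
    totals = Counter()
    with open(path) as f:
        for line in f:
            stack, _, cost = line.rstrip().rpartition(" ")
            if stack:
                totals[stack.split(";")[-1]] += int(cost)
    return totals

# e.g. leaf_totals("demos/helmholtz/helmholtz.txt").most_common(3)
# puts pyop2.global_kernel.GlobalKernel.compile far in front.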
14 changes: 14 additions & 0 deletions firedrake/linear_solver.py
@@ -55,6 +55,8 @@
solver_parameters = solving_utils.set_defaults(solver_parameters,
A.arguments(),
ksp_defaults=self.DEFAULT_KSP_PARAMETERS)
# todo: add offload to solver parameters - how? prefix?
Reviewer (Contributor): This needs addressing somehow. Not quite sure what is meant by this comment.


self.A = A
self.comm = A.comm
self._comm = internal_comm(self.comm, self)
@@ -163,6 +165,18 @@
else:
acc = x.dat.vec_wo

# if "cu" in self.A.petscmat.type: # todo: cuda or cu?
# with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self):
# b_cu = PETSc.Vec()
# b_cu.createCUDAWithArrays(rhs)
# u = PETSc.Vec()
# u.createCUDAWithArrays(solution)
# self.ksp.solve(b_cu, u)
# u.getArray()

Linter failure (GitHub Actions / Run linter): firedrake/linear_solver.py:176:1: W293 blank line contains whitespace
# else:
# instead: preconditioner
Reviewer (Contributor): delete


with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self):
self.ksp.solve(rhs, solution)

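If the commented-out GPU path above were kept rather than deleted, a tidied sketch might read as follows. This is hypothetical and assumes a CUDA-enabled PETSc build; one answer to the "cuda or cu?" todo is that PETSc's CUDA matrix types end in "cusparse" while vector types use "cuda":

if "cusparse" in self.A.petscmat.type:
    with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, \
            dmhooks.add_hooks(self.ksp.dm, self):
        # Wrap the host arrays in CUDA vectors; data is shared with the
        # host, and device copies happen lazily inside the solve.
        b_cu = PETSc.Vec().createCUDAWithArrays(rhs)
        u_cu = PETSc.Vec().createCUDAWithArrays(solution)
        self.ksp.solve(b_cu, u_cu)
        u_cu.getArray()  # force the device-to-host sync of the result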
1 change: 1 addition & 0 deletions firedrake/preconditioners/__init__.py
@@ -12,3 +12,4 @@
from firedrake.preconditioners.fdm import * # noqa: F401
from firedrake.preconditioners.hiptmair import * # noqa: F401
from firedrake.preconditioners.facet_split import * # noqa: F401
from firedrake.preconditioners.offload import * # noqa: F401
117 changes: 117 additions & 0 deletions firedrake/preconditioners/offload.py
@@ -0,0 +1,117 @@
from firedrake.preconditioners.base import PCBase
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.petsc import PETSc
from firedrake.ufl_expr import TestFunction, TrialFunction
import firedrake.dmhooks as dmhooks
from firedrake.dmhooks import get_function_space
Reviewer (Contributor): Probably an unnecessary import.


__all__ = ("OffloadPC",)


class OffloadPC(PCBase):
"""Offload PC from CPU to GPU and back.
Reviewer (Contributor): This docstring could perhaps contain more detail about what is actually happening, and even perhaps why one may wish to do this.
Reviewer (Contributor): E.g. this is only for CUDA GPUs.


Internally this makes a PETSc PC object that can be controlled by
options using the extra options prefix ``offload_``.
"""

_prefix = "offload_"

def initialize(self, pc):
with PETSc.Log.Event("Event: initialize offload"): #
Reviewer (Contributor): trailing #. Also should probably just decorate the method with @PETSc.Log.EventDecorator.
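A sketch of that suggestion (petsc4py's Log.EventDecorator accepts an optional event name):

@PETSc.Log.EventDecorator("Event: initialize offload")
def initialize(self, pc):
    A, P = pc.getOperators()
    # ... rest of the method body unchanged, un-indented by one level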

A, P = pc.getOperators()

outer_pc = pc
appctx = self.get_appctx(pc)
fcp = appctx.get("form_compiler_parameters")

V = get_function_space(pc.getDM())
if len(V) == 1:
V = FunctionSpace(V.mesh(), V.ufl_element())
else:
V = MixedFunctionSpace([V_ for V_ in V])
Reviewer (Contributor), on lines +28 to +32: Is it strictly necessary to redefine V here? Seems a bit pointless.

test = TestFunction(V)
trial = TrialFunction(V)

(a, bcs) = self.form(pc, test, trial)

if P.type == "assembled":
context = P.getPythonContext()
# It only makes sense to preconditioner/invert a diagonal
# block in general. That's all we're going to allow.
Reviewer (Contributor), on lines +40 to +41: can delete this as the error message makes this clear.

if not context.on_diag:
raise ValueError("Only makes sense to invert diagonal block")

prefix = pc.getOptionsPrefix()
options_prefix = prefix + self._prefix

mat_type = PETSc.Options().getString(options_prefix + "mat_type", "cusparse")

# Convert matrix to ajicusparse
with PETSc.Log.Event("Event: matrix offload"):
P_cu = P.convert(mat_type='aijcusparse') # todo
Reviewer (Contributor): Unsure about the # todo comment here. Also: are we doing a host-device copy here? Nothing about the code makes that clear, so perhaps a comment would be useful.


# Transfer nullspace
P_cu.setNullSpace(P.getNullSpace())
tnullsp = P.getTransposeNullSpace()
if tnullsp.handle != 0:
P_cu.setTransposeNullSpace(tnullsp)
P_cu.setNearNullSpace(P.getNearNullSpace())

# PC object set-up
pc = PETSc.PC().create(comm=outer_pc.comm)
pc.incrementTabLevel(1, parent=outer_pc)

# We set a DM and an appropriate SNESContext on the constructed PC
# so one can do e.g. multigrid or patch solves.
dm = outer_pc.getDM()
self._ctx_ref = self.new_snes_ctx(
outer_pc, a, bcs, mat_type,
fcp=fcp, options_prefix=options_prefix
)

pc.setDM(dm)
pc.setOptionsPrefix(options_prefix)
pc.setOperators(A, P_cu)
self.pc = pc
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref, save=False):
pc.setFromOptions()

def update(self, pc):
_, P = pc.getOperators()
_, P_cu = self.pc.getOperators()
P.copy(P_cu)

def form(self, pc, test, trial):
_, P = pc.getOperators()
if P.getType() == "python":
context = P.getPythonContext()
return (context.a, context.row_bcs)
else:
context = dmhooks.get_appctx(pc.getDM())
return (context.Jp or context.J, context._problem.bcs)

# Convert vectors to CUDA, solve and get solution on CPU back
Reviewer (Contributor): Comment should be inside the method (IMO).

def apply(self, pc, x, y):
with PETSc.Log.Event("Event: apply offload"): #
Reviewer (Contributor): Decorate the method (also trailing #).

dm = pc.getDM()
with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
with PETSc.Log.Event("Event: vectors offload"):
y_cu = PETSc.Vec() # begin
y_cu.createCUDAWithArrays(y)
Reviewer (Contributor), on lines +100 to +101, suggested change:
    y_cu = PETSc.Vec() # begin
    y_cu.createCUDAWithArrays(y)
becomes
    y_cu = PETSc.Vec().createCUDAWithArrays(y)
Reviewer (Contributor): should work I think.

x_cu = PETSc.Vec()
x_cu.createCUDAWithArrays(x) # end
Reviewer (Contributor), on lines +102 to +103, suggested change:
    x_cu = PETSc.Vec()
    x_cu.createCUDAWithArrays(x) # end
becomes
    x_cu = PETSc.Vec().createCUDAWithArrays(x)

with PETSc.Log.Event("Event: solve"):
self.pc.apply(x_cu, y_cu) #
Reviewer (Contributor), suggested change: drop the trailing # from self.pc.apply(x_cu, y_cu).

with PETSc.Log.Event("Event: vectors copy back"):
y.copy(y_cu) #
Reviewer (Contributor), suggested change: drop the trailing # from y.copy(y_cu).
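Taken together, the suggestions above would leave apply() reading roughly as below (a sketch, not a committed change). One extra caveat: petsc4py's Vec.copy(result) copies self into result, so copying the device result back into y would be y_cu.copy(y) rather than y.copy(y_cu) as the diff has it:

def apply(self, pc, x, y):
    dm = pc.getDM()
    with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
        with PETSc.Log.Event("Event: vectors offload"):
            y_cu = PETSc.Vec().createCUDAWithArrays(y)
            x_cu = PETSc.Vec().createCUDAWithArrays(x)
        with PETSc.Log.Event("Event: solve"):
            self.pc.apply(x_cu, y_cu)
        with PETSc.Log.Event("Event: vectors copy back"):
            y_cu.copy(y)  # device -> host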


def applyTranspose(self, pc, X, Y):
raise NotImplementedError
Reviewer (Contributor): Maybe have a useful error message? Not sure what the usual approach is here.
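One possible reading of that suggestion (the wording is hypothetical):

def applyTranspose(self, pc, x, y):
    raise NotImplementedError("OffloadPC does not implement applyTranspose")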


def view(self, pc, viewer=None):
super().view(pc, viewer)
print("viewing PC")
Reviewer (Contributor), suggested change: delete the print("viewing PC") line.

if hasattr(self, "pc"):
viewer.printfASCII("PC to solve on GPU\n")
self.pc.view(viewer)
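For context, a Python preconditioner like this is normally selected through Firedrake solver parameters. A minimal usage sketch, assuming OffloadPC is re-exported at the top level as the __init__.py change suggests; a, L and u are the usual bilinear form, linear form and solution Function, and the inner offload_pc_type value is illustrative only:

solve(a == L, u, solver_parameters={
    "ksp_type": "cg",
    "pc_type": "python",
    "pc_python_type": "firedrake.OffloadPC",
    "offload_pc_type": "jacobi",  # options under the extra "offload_" prefix
})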
21 changes: 12 additions & 9 deletions firedrake/solving.py
@@ -252,15 +252,18 @@
options_prefix=options_prefix)
if isinstance(x, firedrake.Vector):
x = x.function
# linear MG doesn't need RHS, supply zero.
lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs)
mat_type = A.mat_type
appctx = solver_parameters.get("appctx", {})
ctx = solving_utils._SNESContext(lvp,
mat_type=mat_type,
pmat_type=mat_type,
appctx=appctx,
options_prefix=options_prefix)
if not isinstance(A, firedrake.matrix.AssembledMatrix):
# linear MG doesn't need RHS, supply zero.
lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs)
mat_type = A.mat_type
Reviewer (Contributor): pointless line to have, delete and use A.mat_type below.

appctx = solver_parameters.get("appctx", {})
ctx = solving_utils._SNESContext(lvp,
mat_type=mat_type,

Linter failure (GitHub Actions / Run linter): firedrake/solving.py:261:41: E128 continuation line under-indented for visual indent
pmat_type=mat_type,

Linter failure (GitHub Actions / Run linter): firedrake/solving.py:262:41: E128 continuation line under-indented for visual indent
appctx=appctx,

Linter failure (GitHub Actions / Run linter): firedrake/solving.py:263:41: E128 continuation line under-indented for visual indent
options_prefix=options_prefix)

Linter failure (GitHub Actions / Run linter): firedrake/solving.py:264:41: E128 continuation line under-indented for visual indent
else:
ctx = None
dm = solver.ksp.dm

with dmhooks.add_hooks(dm, solver, appctx=ctx):
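For reference, the four E128 failures above would be resolved by aligning the continuation lines with the opening parenthesis of the call (inside the new if block, with its extra indentation level):

ctx = solving_utils._SNESContext(lvp,
                                 mat_type=mat_type,
                                 pmat_type=mat_type,
                                 appctx=appctx,
                                 options_prefix=options_prefix)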