OffloadPC (CUDA GPU) #3784
base: master
Changes from all commits
07ba27e
00d5d75
2aae948
6ed6495
5f75618
500b7d3
2778be5
7f8d3f3
6046bab
@@ -0,0 +1,57 @@
Main Stage 366614
Main Stage;firedrake 44369
Main Stage;firedrake;firedrake.solving.solve 86
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve 196
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve 140
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval 736
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute 212
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral 112
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;ParLoopExecute;Parloop_Cells_wrap_form0_cell_integral;pyop2.global_kernel.GlobalKernel.compile 415552
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESFunctionEval;firedrake.tsfc_interface.compile_form 42597
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval 866
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute 149
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral 136
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.solve;SNESSolve;SNESJacobianEval;ParLoopExecute;Parloop_Cells_wrap_form00_cell_integral;pyop2.global_kernel.GlobalKernel.compile 407506
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__ 1771
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form 56423
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.tsfc_interface.compile_form;firedrake.formmanipulation.split_form 1907
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.NonlinearVariationalSolver.__init__;firedrake.solving_utils._SNESContext.__init__ 618
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__ 145
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.ufl_expr.action 4387
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__ 332
Main Stage;firedrake;firedrake.solving.solve;firedrake.variational_solver.LinearVariationalProblem.__init__;firedrake.variational_solver.NonlinearVariationalProblem.__init__;firedrake.ufl_expr.adjoint 2798
Main Stage;firedrake;firedrake.function.Function.interpolate 342
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble 5644
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 29
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 298
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 204
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 682292
Main Stage;firedrake;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 40658
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write 2473
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate 303
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble 1080
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate 23
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute 328
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel 165
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.SameMeshInterpolator._interpolate;ParLoopExecute;Parloop_Cells_wrap_expression_kernel;pyop2.global_kernel.GlobalKernel.compile 663410
Main Stage;firedrake;firedrake.output.vtk_output.VTKFile.write;firedrake.function.Function.interpolate;firedrake.assemble.assemble;firedrake.interpolation.make_interpolator 55147
Main Stage;firedrake;firedrake.__init__ 495196
Main Stage;firedrake;firedrake.assemble.assemble 949
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute 310
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral 95
Main Stage;firedrake;firedrake.assemble.assemble;ParLoopExecute;Parloop_Cells_wrap_form_cell_integral;pyop2.global_kernel.GlobalKernel.compile 355507
Main Stage;firedrake;firedrake.assemble.assemble;firedrake.tsfc_interface.compile_form 20219
Main Stage;firedrake;CreateFunctionSpace 919
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace 79
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__ 165
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data 13
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__ 825
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;FunctionSpaceData: CreateElement 1274
Main Stage;firedrake;CreateFunctionSpace;CreateFunctionSpace;firedrake.functionspaceimpl.FunctionSpace.__init__;firedrake.functionspacedata.get_shared_data;firedrake.functionspacedata.FunctionSpaceData.__init__;firedrake.mesh.MeshTopology._facets 789
Main Stage;firedrake;CreateFunctionSpace;CreateMesh 147
Main Stage;firedrake;CreateFunctionSpace;CreateMesh;Mesh: numbering 376
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh 12
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh 11
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh 834
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;CreateMesh 676
Main Stage;firedrake;firedrake.utility_meshes.UnitSquareMesh;firedrake.utility_meshes.SquareMesh;firedrake.utility_meshes.RectangleMesh;DMPlexInterp 382
@@ -55,6 +55,8 @@
         solver_parameters = solving_utils.set_defaults(solver_parameters,
                                                        A.arguments(),
                                                        ksp_defaults=self.DEFAULT_KSP_PARAMETERS)
+        # todo: add offload to solver parameters - how? prefix?

Review comment: This needs addressing somehow. Not quite sure what is meant by this comment.

         self.A = A
         self.comm = A.comm
         self._comm = internal_comm(self.comm, self)
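One way the todo above could go is to leave the solver parameters alone and select offloading through the usual PETSc options mechanism, using the ``offload_`` prefix registered by the new preconditioner in this PR. A hypothetical usage sketch (it assumes ``OffloadPC`` ends up importable as ``firedrake.OffloadPC`` and that PETSc was built with CUDA support):

from firedrake import *

mesh = UnitSquareMesh(32, 32)
V = FunctionSpace(mesh, "CG", 1)
u, v = TrialFunction(V), TestFunction(V)
a = (inner(grad(u), grad(v)) + inner(u, v)) * dx
L = inner(Constant(1.0), v) * dx
uh = Function(V)
solve(a == L, uh, solver_parameters={
    "ksp_type": "cg",
    "pc_type": "python",
    "pc_python_type": "firedrake.OffloadPC",  # hypothetical import path for the new PC
    "offload_pc_type": "jacobi",              # inner PC options live under the offload_ prefix
    "offload_mat_type": "cusparse",
})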
@@ -163,6 +165,18 @@
         else:
             acc = x.dat.vec_wo

+        # if "cu" in self.A.petscmat.type:  # todo: cuda or cu?
+        #     with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self):
+        #         b_cu = PETSc.Vec()
+        #         b_cu.createCUDAWithArrays(rhs)
+        #         u = PETSc.Vec()
+        #         u.createCUDAWithArrays(solution)
+        #         self.ksp.solve(b_cu, u)
+        #         u.getArray()
+        # else:
+        # instead: preconditioner

Review comment: delete.

         with self.inserted_options(), b.dat.vec_ro as rhs, acc as solution, dmhooks.add_hooks(self.ksp.dm, self):
             self.ksp.solve(rhs, solution)
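For reference, this is roughly what the commented-out path was doing with petsc4py (a sketch; it needs a CUDA-enabled PETSc build, and the approach is superseded by the OffloadPC preconditioner added elsewhere in this PR):

import numpy as np
from petsc4py import PETSc

rhs = np.ones(100, dtype=PETSc.ScalarType)
# Wrap existing host arrays in CUDA vectors; PETSc mirrors the data on the GPU.
b_cu = PETSc.Vec().createCUDAWithArrays(cpuarray=rhs, comm=PETSc.COMM_SELF)
u = PETSc.Vec().createCUDAWithArrays(cpuarray=np.zeros_like(rhs), comm=PETSc.COMM_SELF)
# ksp.solve(b_cu, u) would then run the Krylov solve with GPU-resident vectors.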
@@ -0,0 +1,117 @@
from firedrake.preconditioners.base import PCBase
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.petsc import PETSc
from firedrake.ufl_expr import TestFunction, TrialFunction
import firedrake.dmhooks as dmhooks
from firedrake.dmhooks import get_function_space

Review comment: Probably an unnecessary import.

__all__ = ("OffloadPC",)
class OffloadPC(PCBase):
    """Offload PC from CPU to GPU and back.

Review comment: This docstring could perhaps contain more detail about what is actually happening, and even perhaps why one may wish to do this.

Review comment: E.g. this is only for CUDA GPUs.

    Internally this makes a PETSc PC object that can be controlled by
    options using the extra options prefix ``offload_``.
    """
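A possible expanded docstring along the lines asked for above (a sketch; the wording is only a suggestion based on what the code below does):

    """Offload a preconditioner solve from the CPU to a CUDA GPU.

    The assembled preconditioning matrix is converted to PETSc's
    ``aijcusparse`` format (copying it onto the device), a PC is built
    around the GPU matrix, and each application wraps the incoming vectors
    in CUDA vectors so that the preconditioner is applied on the GPU before
    the result is copied back to the host.  This requires a CUDA-enabled
    PETSc build and therefore only works with CUDA (NVIDIA) GPUs.

    Internally this makes a PETSc PC object that can be controlled by
    options using the extra options prefix ``offload_``.
    """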
    _prefix = "offload_"

    def initialize(self, pc):
        with PETSc.Log.Event("Event: initialize offload"):  #

Review comment: trailing … Also should probably just decorate the method with …
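A sketch of the decorator alternative, assuming the reviewer has petsc4py's ``PETSc.Log.EventDecorator`` in mind (the inline reference in the comment above was not captured):

    @PETSc.Log.EventDecorator("OffloadPC: initialize")
    def initialize(self, pc):
        # body unchanged, no explicit "with PETSc.Log.Event(...)" block needed
        A, P = pc.getOperators()
        ...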
            A, P = pc.getOperators()

            outer_pc = pc
            appctx = self.get_appctx(pc)
            fcp = appctx.get("form_compiler_parameters")

            V = get_function_space(pc.getDM())
            if len(V) == 1:
                V = FunctionSpace(V.mesh(), V.ufl_element())
            else:
                V = MixedFunctionSpace([V_ for V_ in V])

Review comment (on lines +28 to +32): Is it strictly necessary to redefine …
            test = TestFunction(V)
            trial = TrialFunction(V)

            (a, bcs) = self.form(pc, test, trial)

            if P.type == "assembled":
                context = P.getPythonContext()
                # It only makes sense to precondition/invert a diagonal
                # block in general. That's all we're going to allow.

Review comment (on lines +40 to +41): can delete this as the error message makes this clear.

                if not context.on_diag:
                    raise ValueError("Only makes sense to invert diagonal block")

            prefix = pc.getOptionsPrefix()
            options_prefix = prefix + self._prefix

            mat_type = PETSc.Options().getString(options_prefix + "mat_type", "cusparse")

            # Convert matrix to aijcusparse
            with PETSc.Log.Event("Event: matrix offload"):
                P_cu = P.convert(mat_type='aijcusparse')  # todo

Review comment: Unsure about … Also, here are we doing a host-device copy? Nothing about the code makes that clear, so perhaps a comment would be useful.
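A version of the conversion with the clarifying comment asked for above; my understanding (not verified against PETSc internals) is that converting to ``aijcusparse`` allocates device storage and copies the matrix values from host to GPU:

            # Convert matrix to aijcusparse.  This is the host -> device step:
            # the AIJ data is copied into a cuSPARSE matrix resident on the GPU.
            with PETSc.Log.Event("Event: matrix offload"):
                P_cu = P.convert(mat_type="aijcusparse")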
            # Transfer nullspace
            P_cu.setNullSpace(P.getNullSpace())
            tnullsp = P.getTransposeNullSpace()
            if tnullsp.handle != 0:
                P_cu.setTransposeNullSpace(tnullsp)
            P_cu.setNearNullSpace(P.getNearNullSpace())

            # PC object set-up
            pc = PETSc.PC().create(comm=outer_pc.comm)
            pc.incrementTabLevel(1, parent=outer_pc)

            # We set a DM and an appropriate SNESContext on the constructed PC
            # so one can do e.g. multigrid or patch solves.
            dm = outer_pc.getDM()
            self._ctx_ref = self.new_snes_ctx(
                outer_pc, a, bcs, mat_type,
                fcp=fcp, options_prefix=options_prefix
            )

            pc.setDM(dm)
            pc.setOptionsPrefix(options_prefix)
            pc.setOperators(A, P_cu)
            self.pc = pc
            with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref, save=False):
                pc.setFromOptions()
    def update(self, pc):
        _, P = pc.getOperators()
        _, P_cu = self.pc.getOperators()
        P.copy(P_cu)

    def form(self, pc, test, trial):
        _, P = pc.getOperators()
        if P.getType() == "python":
            context = P.getPythonContext()
            return (context.a, context.row_bcs)
        else:
            context = dmhooks.get_appctx(pc.getDM())
            return (context.Jp or context.J, context._problem.bcs)

    # Convert vectors to CUDA, solve and get solution on CPU back

Review comment: Comment should be inside the method (IMO).
    def apply(self, pc, x, y):
        with PETSc.Log.Event("Event: apply offload"):  #

Review comment: Decorate the method (also trailing …)

            dm = pc.getDM()
            with dmhooks.add_hooks(dm, self, appctx=self._ctx_ref):
                with PETSc.Log.Event("Event: vectors offload"):
                    y_cu = PETSc.Vec()  # begin
                    y_cu.createCUDAWithArrays(y)

Review comment (on lines +100 to +101): Suggested change … should work I think.

                    x_cu = PETSc.Vec()
                    x_cu.createCUDAWithArrays(x)  # end

Review comment (on lines +102 to +103): Suggested change …

                with PETSc.Log.Event("Event: solve"):
                    self.pc.apply(x_cu, y_cu)  #

Review comment: Suggested change …

                with PETSc.Log.Event("Event: vectors copy back"):
                    y.copy(y_cu)  #

Review comment: Suggested change …
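The suggested changes above were not captured. For the copy back, petsc4py's ``vec.copy(result)`` copies ``vec`` into ``result``, so ``y.copy(y_cu)`` appears to go host-to-device; assuming the intent is to return the GPU solution to the host, a sketch of the fix is:

                with PETSc.Log.Event("Event: vectors copy back"):
                    y_cu.copy(result=y)  # device -> host: GPU solution back into y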
    def applyTranspose(self, pc, X, Y):
        raise NotImplementedError

Review comment: Maybe have a useful error message? Not sure what the usual approach is here.
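A minimal version with a more informative message (the wording is only a suggestion):

    def applyTranspose(self, pc, X, Y):
        raise NotImplementedError("OffloadPC does not implement applyTranspose")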
    def view(self, pc, viewer=None):
        super().view(pc, viewer)
        print("viewing PC")

Review comment: Suggested change …

        if hasattr(self, "pc"):
            viewer.printfASCII("PC to solve on GPU\n")
            self.pc.view(viewer)
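The content of that suggested change was not captured either; presumably it drops the stray ``print`` and reports only through the viewer, something like:

    def view(self, pc, viewer=None):
        super().view(pc, viewer)
        if hasattr(self, "pc"):
            viewer.printfASCII("PC to solve on GPU\n")
            self.pc.view(viewer)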
@@ -252,15 +252,18 @@
                                      options_prefix=options_prefix)
     if isinstance(x, firedrake.Vector):
         x = x.function
-    # linear MG doesn't need RHS, supply zero.
-    lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs)
-    mat_type = A.mat_type
-    appctx = solver_parameters.get("appctx", {})
-    ctx = solving_utils._SNESContext(lvp,
-                                     mat_type=mat_type,
-                                     pmat_type=mat_type,
-                                     appctx=appctx,
-                                     options_prefix=options_prefix)
+    if not isinstance(A, firedrake.matrix.AssembledMatrix):
+        # linear MG doesn't need RHS, supply zero.
+        lvp = vs.LinearVariationalProblem(a=A.a, L=0, u=x, bcs=A.bcs)
+        mat_type = A.mat_type

Review comment: pointless line to have, delete and use …
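What I take the suggestion to be (an assumption, since the referenced code was not captured): drop the temporary and pass ``A.mat_type`` directly, e.g.

        appctx = solver_parameters.get("appctx", {})
        ctx = solving_utils._SNESContext(lvp,
                                         mat_type=A.mat_type,
                                         pmat_type=A.mat_type,
                                         appctx=appctx,
                                         options_prefix=options_prefix)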
+        appctx = solver_parameters.get("appctx", {})
+        ctx = solving_utils._SNESContext(lvp,
+                                         mat_type=mat_type,
+                                         pmat_type=mat_type,
+                                         appctx=appctx,
+                                         options_prefix=options_prefix)
+    else:
+        ctx = None
     dm = solver.ksp.dm

     with dmhooks.add_hooks(dm, solver, appctx=ctx):
Review comment: Delete this file as discussed.