Skip to content

Commit

Permalink
Merge pull request #2148 from mfem/device-prolongation-fix
Browse files Browse the repository at this point in the history
Fix a few issues in DeviceConformingProlongationOperator
  • Loading branch information
tzanio authored Apr 13, 2021
2 parents 9dbd5e5 + a380039 commit f4a99fa
Showing 1 changed file with 26 additions and 20 deletions.
46 changes: 26 additions & 20 deletions fem/pfespace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3225,6 +3225,7 @@ DeviceConformingProlongationOperator::DeviceConformingProlongationOperator(
const int nb_connections = nbr_ldof.Size_of_connections();
ext_ldof.SetSize(nb_connections);
ext_ldof.CopyFrom(nbr_ldof.GetJ());
ext_ldof.GetMemory().UseDevice(true);
ext_buf.SetSize(nb_connections);
ext_buf.UseDevice(true);
ext_buf_offsets = nbr_ldof.GetIMemory();
Expand Down Expand Up @@ -3256,61 +3257,68 @@ DeviceConformingProlongationOperator::DeviceConformingProlongationOperator(
MFEM_ASSERT(pfes.GetRestrictionMatrix()->Height() == pfes.GetTrueVSize(), "");
}

// Gather helper: for every i, out[i] = in[indices[i]].
// NOTE(review): this listing is a rendered diff without +/- markers, so two
// variants of the signature and of the kernel line appear back to back.
// Presumably the variant taking an explicit length N is the removed one and
// the variant using indices.Size() is its replacement -- confirm against the
// repository before copying this text.
static void ExtractSubVector(const int N,
const Array<int> &indices,
static void ExtractSubVector(const Array<int> &indices,
const Vector &in, Vector &out)
{
// One output entry is produced per index, so the two sizes must agree.
MFEM_ASSERT(indices.Size() == out.Size(), "incompatible sizes!");
// Accessors for the kernel below: 'out' is only assigned, 'in' and 'indices'
// are only read.
auto y = out.Write();
const auto x = in.Read();
const auto I = indices.Read();
// Each iteration assigns its own y[i]; only the source position I[i] may
// repeat across iterations.
MFEM_FORALL(i, N, y[i] = x[I[i]];); // indices can be repeated
MFEM_FORALL(i, indices.Size(), y[i] = x[I[i]];); // indices can be repeated
}

// Fills shr_buf with the entries of x selected by shr_ltdof (see the formula
// below); does nothing when shr_ltdof is empty.
// NOTE(review): the two ExtractSubVector calls are the before/after lines of
// a rendered diff (presumably the call passing an explicit size is the
// removed one); only one of them exists in the actual file.
void DeviceConformingProlongationOperator::BcastBeginCopy(
const Vector &x) const
{
// shr_buf[i] = src[shr_ltdof[i]]
if (shr_ltdof.Size() == 0) { return; }
ExtractSubVector(shr_ltdof.Size(), shr_ltdof, x, shr_buf);
ExtractSubVector(shr_ltdof, x, shr_buf);
// If the above kernel is executed asynchronously, we should wait for it to
// complete
if (mpi_gpu_aware) { MFEM_STREAM_SYNC; }
}

// Scatter helper: for every i, out[indices[i]] = in[i]. Entries of 'out'
// whose position is not listed in 'indices' are not assigned by the kernel.
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// signature, the accessor for 'out' and the kernel line each appear twice.
// Presumably the variant with the explicit length N and out.Write() is the
// removed one, and the variant with the size assertion, out.ReadWrite() and
// indices.Size() is its replacement -- confirm against the repository.
static void SetSubVector(const int N,
const Array<int> &indices,
static void SetSubVector(const Array<int> &indices,
const Vector &in, Vector &out)
{
auto y = out.Write();
// One input entry is consumed per index, so the two sizes must agree.
MFEM_ASSERT(indices.Size() == in.Size(), "incompatible sizes!");
// Use ReadWrite() since we modify only a subset of the indices:
auto y = out.ReadWrite();
const auto x = in.Read();
const auto I = indices.Read();
MFEM_FORALL(i, N, y[I[i]] = x[i];);
MFEM_FORALL(i, indices.Size(), y[I[i]] = x[i];);
}

// Scatters x into y at the positions listed in ltdof_ldof (see the formula
// below); does nothing when ltdof_ldof is empty.
// NOTE(review): the two SetSubVector calls are the before/after lines of a
// rendered diff (presumably the call passing an explicit size is the removed
// one); only one of them exists in the actual file.
void DeviceConformingProlongationOperator::BcastLocalCopy(
const Vector &x, Vector &y) const
{
// dst[ltdof_ldof[i]] = src[i]
if (ltdof_ldof.Size() == 0) { return; }
SetSubVector(ltdof_ldof.Size(), ltdof_ldof, x, y);
SetSubVector(ltdof_ldof, x, y);
}

// Scatters ext_buf into y at the positions listed in ext_ldof (see the
// formula below); does nothing when ext_ldof is empty.
// NOTE(review): the two SetSubVector calls are the before/after lines of a
// rendered diff (presumably the call passing an explicit size is the removed
// one); only one of them exists in the actual file.
void DeviceConformingProlongationOperator::BcastEndCopy(
Vector &y) const
{
// dst[ext_ldof[i]] = ext_buf[i]
if (ext_ldof.Size() == 0) { return; }
SetSubVector(ext_ldof.Size(), ext_ldof, ext_buf, y);
SetSubVector(ext_ldof, ext_buf, y);
}

void DeviceConformingProlongationOperator::Mult(const Vector &x,
Vector &y) const
{
const GroupTopology &gtopo = gc.GetGroupTopology();
int req_counter = 0;
// Make sure 'y' is marked as valid on device and for use on device.
// This avoids an unnecessary host-to-device copy when the input 'y' is
// valid on host: in 'y.SetSubVector(ext_ldof, 0.0)' (when local is true)
// or in BcastLocalCopy (when local is false).
y.Write();
if (local)
{
y = 0.0;
// done on device since we've marked ext_ldof for use on device:
y.SetSubVector(ext_ldof, 0.0);
}
else
{
Expand Down Expand Up @@ -3357,7 +3365,7 @@ void DeviceConformingProlongationOperator::ReduceBeginCopy(
{
// ext_buf[i] = src[ext_ldof[i]]
if (ext_ldof.Size() == 0) { return; }
ExtractSubVector(ext_ldof.Size(), ext_ldof, x, ext_buf);
ExtractSubVector(ext_ldof, x, ext_buf);
// If the above kernel is executed asynchronously, we should wait for it to
// complete
if (mpi_gpu_aware) { MFEM_STREAM_SYNC; }
Expand All @@ -3368,22 +3376,21 @@ void DeviceConformingProlongationOperator::ReduceLocalCopy(
{
// dst[i] = src[ltdof_ldof[i]]
if (ltdof_ldof.Size() == 0) { return; }
ExtractSubVector(ltdof_ldof.Size(), ltdof_ldof, x, y);
ExtractSubVector(ltdof_ldof, x, y);
}

static void AddSubVector(const int num_unique_dst_indices,
const Array<int> &unique_dst_indices,
static void AddSubVector(const Array<int> &unique_dst_indices,
const Array<int> &unique_to_src_offsets,
const Array<int> &unique_to_src_indices,
const Vector &src,
Vector &dst)
{
auto y = dst.Write();
auto y = dst.ReadWrite();
const auto x = src.Read();
const auto DST_I = unique_dst_indices.Read();
const auto SRC_O = unique_to_src_offsets.Read();
const auto SRC_I = unique_to_src_indices.Read();
MFEM_FORALL(i, num_unique_dst_indices,
MFEM_FORALL(i, unique_dst_indices.Size(),
{
const int dst_idx = DST_I[i];
double sum = y[dst_idx];
Expand All @@ -3396,9 +3403,8 @@ static void AddSubVector(const int num_unique_dst_indices,
// Adds the contents of shr_buf into y through AddSubVector (see the formula
// below). unq_ltdof is passed as the destination index list, unq_shr_i and
// unq_shr_j as the offsets/indices into the source buffer. Does nothing when
// unq_ltdof is empty.
// NOTE(review): the size check and the AddSubVector call each appear twice
// because this listing is a rendered diff; presumably the lines using the
// local unq_ltdof_size are the removed ones -- confirm against the
// repository.
void DeviceConformingProlongationOperator::ReduceEndAssemble(Vector &y) const
{
// dst[shr_ltdof[i]] += shr_buf[i]
const int unq_ltdof_size = unq_ltdof.Size();
if (unq_ltdof_size == 0) { return; }
AddSubVector(unq_ltdof_size, unq_ltdof, unq_shr_i, unq_shr_j, shr_buf, y);
if (unq_ltdof.Size() == 0) { return; }
AddSubVector(unq_ltdof, unq_shr_i, unq_shr_j, shr_buf, y);
}

void DeviceConformingProlongationOperator::MultTranspose(const Vector &x,
Expand Down

0 comments on commit f4a99fa

Please sign in to comment.