Skip to content

Commit

Permalink
Making terrain list GPU aware
Browse files Browse the repository at this point in the history
  • Loading branch information
hgopalan committed Dec 10, 2024
1 parent 90da8df commit af6f7f5
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions Source/SourceTerms/ERF_TerrainDrag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,22 @@ TerrainDrag::define_terrain_blank_field(
m_terrain_blank.reset();
m_terrain_blank = std::make_unique<MultiFab>(ba, dm, 1, 1);
m_terrain_blank->setVal(0.);

const auto xterrain_size = m_x_terrain.size();
amrex::Gpu::DeviceVector<amrex::Real> d_xterrain(xterrain_size);
amrex::Gpu::DeviceVector<amrex::Real> d_yterrain(xterrain_size);
amrex::Gpu::DeviceVector<amrex::Real> d_zterrain(xterrain_size);
amrex::Gpu::copy(
amrex::Gpu::hostToDevice, m_x_terrain.begin(), m_x_terrain.end(),
d_xterrain.begin());
amrex::Gpu::copy(
amrex::Gpu::hostToDevice, m_y_terrain.begin(), m_y_terrain.end(),
d_yterrain.begin());
amrex::Gpu::copy(
amrex::Gpu::hostToDevice, m_height_terrain.begin(), m_height_terrain.end(),
d_zterrain.begin());
const auto* xterrain_ptr = d_xterrain.data();
const auto* yterrain_ptr = d_yterrain.data();
const auto* zterrain_ptr = d_zterrain.data();
// Set the terrain blank data
for (MFIter mfi(*m_terrain_blank); mfi.isValid(); ++mfi) {
Box gtbx = mfi.growntilebox();
Expand All @@ -47,10 +62,10 @@ TerrainDrag::define_terrain_blank_field(
(z_phys_nd) ? z_phys_nd->const_array(mfi) : Array4<const Real>{};
ParallelFor(gtbx, [=] AMREX_GPU_DEVICE(int i, int j, int k) noexcept {
// Loop over terrain points
for (unsigned ii = 0; ii < m_x_terrain.size(); ++ii) {
Real ht = m_height_terrain[ii];
Real xt = m_x_terrain[ii];
Real yt = m_y_terrain[ii];
for (unsigned ii = 0; ii < xterrain_size; ++ii) {
Real ht = zterrain_ptr[ii];
Real xt = xterrain_ptr[ii];
Real yt = yterrain_ptr[ii];
// Physical positions of cell-centers
const Real x = prob_lo[0] + (i + 0.5) * dx[0];
const Real y = prob_lo[1] + (j + 0.5) * dx[1];
Expand Down

0 comments on commit af6f7f5

Please sign in to comment.