Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nonuniform tile sizes to more testers & fix some related issues #143

Merged
merged 19 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions include/slate/BaseMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -904,8 +904,36 @@ BaseMatrix<scalar_t>::BaseMatrix(
++nt_;
}

auto actTileMb = inTileMb;
if (last_mb_ != actTileMb( mt_-1 )) {
// Pass variables into lambda w/out reference to `this`.
int64_t mt_1 = mt_-1;
int64_t last_mb = last_mb_;
actTileMb = [inTileMb, last_mb, mt_1](int64_t i) {
if (i == mt_1) {
return last_mb;
}
else {
return inTileMb( i );
}
};
}
auto actTileNb = inTileNb;
if (last_nb_ != actTileNb( nt_-1 )) {
int64_t nt_1 = nt_-1;
int64_t last_nb = last_nb_;
actTileNb = [inTileNb, last_nb, nt_1](int64_t j) {
if (j == nt_1) {
return last_nb;
}
else {
return inTileNb( j );
}
};
}

storage_ = std::make_shared< MatrixStorage< scalar_t > >(
mt_, nt_, inTileMb, inTileNb,
mt_, nt_, actTileMb, actTileNb,
inTileRank, inTileDevice, mpi_comm );

slate_mpi_call(
Expand Down Expand Up @@ -2682,9 +2710,14 @@ void BaseMatrix<scalar_t>::tileGet(int64_t i, int64_t j, int dst_device,
+ " -> " + std::to_string(dst_device));
}

target_layout = layout == LayoutConvert::None ?
src_tile->layout() :
Layout(layout);
target_layout = layout == LayoutConvert::None
? src_tile->layout()
: Layout(layout);
}
else {
target_layout = layout == LayoutConvert::None
? tile_node[dst_device]->layout()
: Layout(layout);
}

if (! tile_node.existsOn(dst_device)) {
Expand All @@ -2700,8 +2733,7 @@ void BaseMatrix<scalar_t>::tileGet(int64_t i, int64_t j, int dst_device,
tileCopyDataLayout( src_tile, dst_tile, target_layout, async );

dst_tile->state(MOSI::Shared);
if (src_tile->stateOn(MOSI::Modified))
src_tile->state(MOSI::Shared);
src_tile->state(MOSI::Shared); // src was either shared or modified
}
if (modify) {
tileModified(i, j, dst_device);
Expand Down Expand Up @@ -3459,6 +3491,7 @@ void BaseMatrix<scalar_t>::tileLayoutConvert(
batch_count =
std::max(batch_count, int64_t(bucket->second.first.size()));
}
allocateBatchArrays( batch_count, 1 );
mgates3 marked this conversation as resolved.
Show resolved Hide resolved

lapack::Queue* queue = comm_queue(device);

Expand All @@ -3469,7 +3502,7 @@ void BaseMatrix<scalar_t>::tileLayoutConvert(
batch_count = bucket->second.first.size();

scalar_t** array_dev = this->array_device(device);
scalar_t** work_array_dev = this->array_device(device) + batch_count;
scalar_t** work_array_dev = array_dev + batch_count;

assert(array_dev != nullptr);
assert(work_array_dev != nullptr);
Expand Down
17 changes: 8 additions & 9 deletions include/slate/internal/MatrixStorage.hh
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ public:
void eraseOn(int device)
{
slate_assert(device >= -1 && device+1 < int(tiles_.size()));
if (tiles_[device+1] != nullptr) {
tiles_[device+1]->state(MOSI::Invalid);
delete tiles_[device+1];
tiles_[device+1] = nullptr;
auto tile = tiles_[device+1];
tiles_[device+1] = nullptr;
if (tile != nullptr) {
delete tile;
--num_instances_;
}
}
Expand Down Expand Up @@ -644,11 +644,10 @@ void MatrixStorage<scalar_t>::allocateBatchArrays(
// Free device arrays.
blas::device_free(array_dev_[i][device], *queue);

// Free queues.
delete compute_queues_[i][device];

// Allocate queues.
compute_queues_[ i ][ device ] = new lapack::Queue( device );
if (compute_queues_[ i ][ device ] == nullptr) {
// Allocate queues.
compute_queues_[ i ][ device ] = new lapack::Queue( device );
}

// Allocate host arrays;
array_host_[i][device]
Expand Down
9 changes: 4 additions & 5 deletions src/getrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ int64_t getrf(
bcast_list_A.push_back({i, k, {A.sub(i, i, k+1, A_nt-1)}});
}
A.template listBcast<target>(
bcast_list_A, Layout::ColMajor, tag_k );
bcast_list_A, target_layout, tag_k );

// Root broadcasts the pivot to all ranks.
// todo: Panel ranks send the pivots to the right.
Expand Down Expand Up @@ -139,7 +139,7 @@ int64_t getrf(
internal::trsm<target>(
Side::Left,
one, std::move( Tkk ), A.sub(k, k, j, j),
priority_1, Layout::ColMajor, queue_jk1 );
priority_1, target_layout, queue_jk1 );

// send A(k, j) across column A(k+1:mt-1, j)
// todo: trsm still operates in ColMajor
Expand Down Expand Up @@ -196,17 +196,16 @@ int64_t getrf(
Side::Left,
one, std::move( Tkk ),
A.sub(k, k, k+1+lookahead, A_nt-1),
priority_0, Layout::ColMajor, queue_1 );
priority_0, target_layout, queue_1 );

// send A(k, kl+1:A_nt-1) across A(k+1:mt-1, kl+1:nt-1)
BcastList bcast_list_A;
for (int64_t j = k+1+lookahead; j < A_nt; ++j) {
// send A(k, j) across column A(k+1:mt-1, j)
bcast_list_A.push_back({k, j, {A.sub(k+1, A_mt-1, j, j)}});
}
// todo: trsm still operates in ColMajor
A.template listBcast<target>(
bcast_list_A, Layout::ColMajor, tag_kl1);
bcast_list_A, target_layout, tag_kl1);

// A(k+1:mt-1, kl+1:nt-1) -= A(k+1:mt-1, k) * A(k, kl+1:nt-1)
internal::gemm<target>(
Expand Down
12 changes: 6 additions & 6 deletions src/getrf_tntpiv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ int64_t getrf_tntpiv(
Side::Right,
one, std::move(Tkk),
A.sub( k+1, A_mt-1, k, k ),
priority_1, Layout::ColMajor, queue_0 );
priority_1, target_layout, queue_0 );
}

#pragma omp task depend(inout:column[k]) \
Expand All @@ -238,7 +238,7 @@ int64_t getrf_tntpiv(
bcast_list.push_back({i, k, {A.sub(i, i, k+1, A_nt-1)}, tag});
}
A.template listBcastMT<target>(
bcast_list, Layout::ColMajor );
bcast_list, target_layout );
}

// update lookahead column(s), high priority
Expand All @@ -262,13 +262,13 @@ int64_t getrf_tntpiv(
internal::trsm<target>(
Side::Left,
one, std::move( Tkk ), A.sub( k, k, j, j ),
priority_1, Layout::ColMajor, queue_jk1 );
priority_1, target_layout, queue_jk1 );

// send A(k, j) across column A(k+1:mt-1, j)
// todo: trsm still operates in ColMajor
A.tileBcast(
k, j, A.sub( k+1, A_mt-1, j, j ),
Layout::ColMajor, tag_j );
target_layout, tag_j );
}

#pragma omp task depend(in:column[k]) \
Expand Down Expand Up @@ -306,7 +306,7 @@ int64_t getrf_tntpiv(
Side::Left,
one, std::move( Tkk ),
A.sub( k, k, k+1+lookahead, A_nt-1 ),
priority_0, Layout::ColMajor, queue_1 );
priority_0, target_layout, queue_1 );
}

#pragma omp task depend(inout:column[k+1+lookahead]) \
Expand All @@ -322,7 +322,7 @@ int64_t getrf_tntpiv(
bcast_list.push_back({k, j, {A.sub(k+1, A_mt-1, j, j)}, tag});
}
A.template listBcastMT<target>(
bcast_list, Layout::ColMajor);
bcast_list, target_layout);

}

Expand Down
31 changes: 31 additions & 0 deletions test/grid_utils.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define SLATE_GRID_UTILS_HH

#include "slate/slate.hh"
#include "scalapack_wrappers.hh"

#include <stdint.h>

Expand Down Expand Up @@ -56,4 +57,34 @@ inline void gridinfo(
gridinfo( mpi_rank, slate::GridOrder::Col, p, q, my_row, my_col );
}

//------------------------------------------------------------------------------
#ifdef SLATE_HAVE_SCALAPACK
// Creates a BLACS process grid of p-by-q processes in the given grid order,
// leaves its context handle in *ictxt, and asserts that BLACS agrees with
// MPI and with gridinfo() about this process's rank and grid coordinates.
//
// grid_order: process ordering; converted with grid_order2str for BLACS.
// p, q:       number of process rows and columns; p*q must not exceed the
//             process count reported by Cblacs_pinfo.
// ictxt:      in/out pointer handed to Cblacs_get and Cblacs_gridinit.
inline void create_ScaLAPACK_context( slate::GridOrder grid_order,
                                      int p, int q,
                                      blas_int* ictxt )
{
    // Rank according to MPI.
    int mpi_rank;
    MPI_Comm_rank( MPI_COMM_WORLD, &mpi_rank );

    // Rank and process count according to BLACS; the two ranks must match.
    blas_int mpi_rank_ = 0;
    blas_int nprocs = 1;
    Cblacs_pinfo( &mpi_rank_, &nprocs );
    slate_assert( mpi_rank_ == mpi_rank );

    // Cblacs_get fills *ictxt (presumably with the default system context
    // -- confirm against the BLACS documentation).
    Cblacs_get( -1, 0, ictxt );

    // The requested grid has to fit in the available processes.
    slate_assert( p*q <= nprocs );
    Cblacs_gridinit( ictxt, grid_order2str( grid_order ), p, q );

    // Coordinates computed by gridinfo() from the rank and grid order.
    int myrow;
    int mycol;
    gridinfo( mpi_rank, grid_order, p, q, &myrow, &mycol );

    // Grid shape and coordinates as reported back by BLACS.
    blas_int p_;
    blas_int q_;
    blas_int myrow_;
    blas_int mycol_;
    Cblacs_gridinfo( *ictxt, &p_, &q_, &myrow_, &mycol_ );

    // Both views of the grid must be identical.
    slate_assert( p == p_ );
    slate_assert( q == q_ );
    slate_assert( myrow == myrow_ );
    slate_assert( mycol == mycol_ );
}
#endif // SLATE_HAVE_SCALAPACK

#endif // SLATE_GRID_UTILS_HH
Loading