Skip to content

Commit

Permalink
fix example; add nosync
Browse files Browse the repository at this point in the history
  • Loading branch information
davidwendt committed Nov 1, 2023
1 parent 3f38612 commit 1640c16
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions cpp/src/text/bpe/byte_pair_encoding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ constexpr int block_size = 512;
* Launched as a thread per byte of the chars array.
* The output is non-zero offsets to locations of unpairable substrings.
* An unpairable substring does not exist in the given map and so will
* never be paired :-(. Fortunately, this can be used as an artificial
* never be paired. Fortunately, this can be used as an artificial
* boundary providing increased parallelism in the BPE kernel.
*
* @tparam MapRefType The type of the map finder object
Expand Down Expand Up @@ -171,12 +171,12 @@ __global__ void bpe_parallel_fn(cudf::column_device_view const d_strings,
auto min_rank = max_rank;

// store all the initial ranks for each pair
// every character but the first and last one will have a pair and a rank
// every character but the first one will have an initial rank
//
// Example:
// string: abcdefghij
// spaces: 1111111111
// ranks: *948516327*
// ranks: *948516327
for (auto itr = d_spaces + lane_idx; itr < end_spaces; itr += block_size) {
if (*itr == 0) { continue; } // skips any UTF-8 continuation bytes
// resolve pair and lookup its rank
Expand Down Expand Up @@ -367,22 +367,22 @@ std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const
auto const mp_map = merge_pairs.impl->get_mp_table_ref(); // lookup table
auto const d_chars_span = cudf::device_span<char const>(d_input_chars, chars_size);
auto up_fn = bpe_unpairable_offsets_fn<decltype(mp_map)>{d_chars_span, first_offset, mp_map};
thrust::transform(rmm::exec_policy(stream), chars_begin, chars_end, d_up_offsets, up_fn);
thrust::transform(rmm::exec_policy_nosync(stream), chars_begin, chars_end, d_up_offsets, up_fn);
auto const up_end = // remove all but the unpairable offsets
thrust::remove(rmm::exec_policy(stream), d_up_offsets, d_up_offsets + chars_size, 0);
thrust::remove(rmm::exec_policy_nosync(stream), d_up_offsets, d_up_offsets + chars_size, 0);
auto const unpairables = thrust::distance(d_up_offsets, up_end); // number of unpairables

// new string boundaries created by combining unpairable offsets with the existing offsets
auto tmp_offsets = rmm::device_uvector<cudf::size_type>(unpairables + input.size() + 1, stream);
thrust::merge(rmm::exec_policy(stream),
thrust::merge(rmm::exec_policy_nosync(stream),
input.offsets_begin(),
input.offsets_end(),
d_up_offsets,
up_end,
tmp_offsets.begin());
// remove any adjacent duplicate offsets (i.e. empty or null rows)
auto const offsets_end =
thrust::unique(rmm::exec_policy(stream), tmp_offsets.begin(), tmp_offsets.end());
thrust::unique(rmm::exec_policy_nosync(stream), tmp_offsets.begin(), tmp_offsets.end());
auto const offsets_total =
static_cast<cudf::size_type>(thrust::distance(tmp_offsets.begin(), offsets_end));
tmp_offsets.resize(offsets_total, stream);
Expand Down Expand Up @@ -423,11 +423,11 @@ std::unique_ptr<cudf::column> byte_pair_encoding(cudf::strings_column_view const
return d_spaces[idx] > 0; // separator to be inserted here
};
auto const copy_end = thrust::copy_if(
rmm::exec_policy(stream), chars_begin + 1, chars_end, d_inserts, offsets_at_non_zero);
rmm::exec_policy_nosync(stream), chars_begin + 1, chars_end, d_inserts, offsets_at_non_zero);

// this will insert the single-byte separator into positions specified in d_inserts
auto const sep_char = thrust::constant_iterator<char>(separator.to_string(stream)[0]);
thrust::merge_by_key(rmm::exec_policy(stream),
thrust::merge_by_key(rmm::exec_policy_nosync(stream),
d_inserts, // where to insert separator byte
copy_end, //
chars_begin, // all indices
Expand Down

0 comments on commit 1640c16

Please sign in to comment.