Skip to content

Commit

Permalink
add logits
Browse files Browse the repository at this point in the history
  • Loading branch information
prashanth058 committed Jun 17, 2024
1 parent 8c6715e commit 884b8df
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/ctranslate2/decoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace ctranslate2 {
std::vector<std::vector<size_t>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<float> logits;
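// Per-step log-probability of the selected token; the decoding loop in src/decoding.cc sizes this to the maximum number of decoding steps.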
};

struct DecodingStepResult {
Expand Down
6 changes: 5 additions & 1 deletion include/ctranslate2/translation.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ namespace ctranslate2 {
std::vector<std::vector<std::string>> hypotheses;
std::vector<float> scores;
std::vector<std::vector<std::vector<float>>> attention;
std::vector<float> logits;

TranslationResult(std::vector<std::vector<std::string>> hypotheses_)
: hypotheses(std::move(hypotheses_))
Expand All @@ -95,10 +96,12 @@ namespace ctranslate2 {

TranslationResult(std::vector<std::vector<std::string>> hypotheses_,
std::vector<float> scores_,
std::vector<std::vector<std::vector<float>>> attention_)
std::vector<std::vector<std::vector<float>>> attention_,
std::vector<float> logits_)
: hypotheses(std::move(hypotheses_))
, scores(std::move(scores_))
, attention(std::move(attention_))
, logits(std::move(logits_))
{
}

Expand All @@ -109,6 +112,7 @@ namespace ctranslate2 {
: hypotheses(num_hypotheses)
, scores(with_score ? num_hypotheses : 0, static_cast<float>(0))
, attention(with_attention ? num_hypotheses : 0)
, logits(with_score ? num_hypotheses : 0)
{
}

Expand Down
7 changes: 6 additions & 1 deletion python/cpp/translation_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@ namespace ctranslate2 {
"Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).")
.def_readonly("attention", &TranslationResult::attention,
"Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).")
.def_readonly("logits", &TranslationResult::logits,
"Logits for each decoding step")

.def("__repr__", [](const TranslationResult& result) {
return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses)))
+ ", scores=" + std::string(py::repr(py::cast(result.scores)))
+ ", attention=" + std::string(py::repr(py::cast(result.attention)))
+ ", logits=" + std::string(py::repr(py::cast(result.logits)))
+ ")";
})

Expand All @@ -39,8 +42,10 @@ namespace ctranslate2 {
throw py::index_error();
py::dict hypothesis;
hypothesis["tokens"] = result.hypotheses[i];
if (result.has_scores())
if (result.has_scores()){
hypothesis["score"] = result.scores[i];
hypothesis["logits"] = result.logits[i];
};
if (result.has_attention())
hypothesis["attention"] = result.attention[i];
return hypothesis;
Expand Down
2 changes: 1 addition & 1 deletion python/ctranslate2/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version information."""

__version__ = "3.24.0"
__version__ = "3.24.1"
1 change: 1 addition & 0 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _maybe_add_library_root(lib_name):
"numpy",
"pyyaml>=5.3,<7",
],
include_package_data=True,
entry_points={
"console_scripts": [
"ct2-fairseq-converter=ctranslate2.converters.fairseq:main",
Expand Down
26 changes: 18 additions & 8 deletions src/decoding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -779,24 +779,29 @@ namespace ctranslate2 {
StorageView logits(dtype, device);
std::vector<dim_t> batch_offset(batch_size);
std::vector<DecodingResult> results(batch_size);

StorageView best_ids(DataType::INT32);
StorageView best_probs(dtype);
StorageView alive_seq(DataType::INT32);
StorageView attention_step;
StorageView attention_step_device(dtype, device);

const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids);

for (dim_t i = 0; i < batch_size; ++i) {
batch_offset[i] = i;
sample_from.at<int32_t>(i) = start_ids[i];
results[i].hypotheses.resize(1);
if (return_scores)
{
results[i].scores.resize(1, 0.f);
results[i].logits.resize(max_step);
};

if (return_attention)
results[i].attention.resize(1);
}

StorageView best_ids(DataType::INT32);
StorageView best_probs(dtype);
StorageView alive_seq(DataType::INT32);
StorageView attention_step;
StorageView attention_step_device(dtype, device);

const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids);

for (dim_t step = 0; step < max_step; ++step) {
convert_to_original_word_ids(decoder, sample_from);
decoder(start_step + step,
Expand Down Expand Up @@ -851,6 +856,8 @@ namespace ctranslate2 {
const size_t batch_id = batch_offset[i];
const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0;
const float score = best_probs.scalar_at<float>({i, 0});
// Log-probability of the selected token at this step (word_id is cast to int32_t to index log_probs).
const float log_prob = log_probs.scalar_at<float>({i, static_cast<int32_t>(word_id)});

if ((!is_eos(word_id, end_ids) || include_eos_in_hypotheses)
&& (return_prefix || step >= prefix_length)) {
Expand All @@ -862,7 +869,10 @@ namespace ctranslate2 {
}

if (return_scores)
{
results[batch_id].scores[0] += score;
results[batch_id].logits[step] = log_prob;
};

bool is_finished = ((is_eos(word_id, end_ids) && step >= prefix_length)
|| (is_last_step(step, max_length, prefix_length, return_prefix)));
Expand Down
5 changes: 4 additions & 1 deletion src/models/sequence_to_sequence.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ namespace ctranslate2 {

final_results.emplace_back(std::move(hypotheses),
std::move(result.scores),
std::move(result.attention));
std::move(result.attention),
std::move(result.logits));
}

return final_results;
Expand Down Expand Up @@ -462,6 +463,8 @@ namespace ctranslate2 {
result.scores.emplace_back(0);
if (options.return_attention)
result.attention.emplace_back(attention);
if (options.return_scores)
result.logits.emplace_back(0);
}

return true;
Expand Down

0 comments on commit 884b8df

Please sign in to comment.