From 884b8df1184819989b27cea74322bec7f44048f4 Mon Sep 17 00:00:00 2001 From: prashanth058 Date: Mon, 17 Jun 2024 15:28:42 +0000 Subject: [PATCH] add logits --- include/ctranslate2/decoding.h | 2 ++ include/ctranslate2/translation.h | 6 +++++- python/cpp/translation_result.cc | 7 ++++++- python/ctranslate2/version.py | 2 +- python/setup.py | 1 + src/decoding.cc | 26 ++++++++++++++++++-------- src/models/sequence_to_sequence.cc | 5 ++++- 7 files changed, 37 insertions(+), 12 deletions(-) diff --git a/include/ctranslate2/decoding.h b/include/ctranslate2/decoding.h index 436280b46..4151a8e43 100644 --- a/include/ctranslate2/decoding.h +++ b/include/ctranslate2/decoding.h @@ -15,6 +15,8 @@ namespace ctranslate2 { std::vector> hypotheses; std::vector scores; std::vector>> attention; + std::vector logits; + // (max_decoding_steps) }; struct DecodingStepResult { diff --git a/include/ctranslate2/translation.h b/include/ctranslate2/translation.h index 8d2ec943a..207d1062d 100644 --- a/include/ctranslate2/translation.h +++ b/include/ctranslate2/translation.h @@ -87,6 +87,7 @@ namespace ctranslate2 { std::vector> hypotheses; std::vector scores; std::vector>> attention; + std::vector logits; TranslationResult(std::vector> hypotheses_) : hypotheses(std::move(hypotheses_)) @@ -95,10 +96,12 @@ namespace ctranslate2 { TranslationResult(std::vector> hypotheses_, std::vector scores_, - std::vector>> attention_) + std::vector>> attention_, + std::vector logits_) : hypotheses(std::move(hypotheses_)) , scores(std::move(scores_)) , attention(std::move(attention_)) + , logits(std::move(logits_)) { } @@ -109,6 +112,7 @@ namespace ctranslate2 { : hypotheses(num_hypotheses) , scores(with_score ? num_hypotheses : 0, static_cast(0)) , attention(with_attention ? num_hypotheses : 0) + , logits(with_score ? num_hypotheses : 0) { } diff --git a/python/cpp/translation_result.cc b/python/cpp/translation_result.cc index 3b8a0790b..8104f5df3 100644 --- a/python/cpp/translation_result.cc +++ b/python/cpp/translation_result.cc @@ -16,11 +16,14 @@ namespace ctranslate2 { "Score of each translation hypothesis (empty if :obj:`return_scores` was disabled).") .def_readonly("attention", &TranslationResult::attention, "Attention matrix of each translation hypothesis (empty if :obj:`return_attention` was disabled).") + .def_readonly("logits", &TranslationResult::logits, + "Logits for each decoding step") .def("__repr__", [](const TranslationResult& result) { return "TranslationResult(hypotheses=" + std::string(py::repr(py::cast(result.hypotheses))) + ", scores=" + std::string(py::repr(py::cast(result.scores))) + ", attention=" + std::string(py::repr(py::cast(result.attention))) + + ", logits=" + std::string(py::repr(py::cast(result.logits))) + ")"; }) @@ -39,8 +42,10 @@ namespace ctranslate2 { throw py::index_error(); py::dict hypothesis; hypothesis["tokens"] = result.hypotheses[i]; - if (result.has_scores()) + if (result.has_scores()){ hypothesis["score"] = result.scores[i]; + hypothesis["logits"] = result.logits[i]; + }; if (result.has_attention()) hypothesis["attention"] = result.attention[i]; return hypothesis; diff --git a/python/ctranslate2/version.py b/python/ctranslate2/version.py index 4416d629f..969042c40 100644 --- a/python/ctranslate2/version.py +++ b/python/ctranslate2/version.py @@ -1,3 +1,3 @@ """Version information.""" -__version__ = "3.24.0" +__version__ = "3.24.1" diff --git a/python/setup.py b/python/setup.py index 3ed3304fc..25dbeb0b9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -112,6 +112,7 @@ def _maybe_add_library_root(lib_name): "numpy", "pyyaml>=5.3,<7", ], + include_package_data=True, entry_points={ "console_scripts": [ "ct2-fairseq-converter=ctranslate2.converters.fairseq:main", diff --git a/src/decoding.cc b/src/decoding.cc index 4c541e490..dbe0d35f0 100644 --- a/src/decoding.cc +++ b/src/decoding.cc @@ -779,24 +779,29 @@ namespace ctranslate2 { StorageView logits(dtype, device); std::vector batch_offset(batch_size); std::vector results(batch_size); + + StorageView best_ids(DataType::INT32); + StorageView best_probs(dtype); + StorageView alive_seq(DataType::INT32); + StorageView attention_step; + StorageView attention_step_device(dtype, device); + + const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids); + for (dim_t i = 0; i < batch_size; ++i) { batch_offset[i] = i; sample_from.at(i) = start_ids[i]; results[i].hypotheses.resize(1); if (return_scores) + { results[i].scores.resize(1, 0.f); + results[i].logits.resize(max_step); + }; + if (return_attention) results[i].attention.resize(1); } - StorageView best_ids(DataType::INT32); - StorageView best_probs(dtype); - StorageView alive_seq(DataType::INT32); - StorageView attention_step; - StorageView attention_step_device(dtype, device); - - const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids); - for (dim_t step = 0; step < max_step; ++step) { convert_to_original_word_ids(decoder, sample_from); decoder(start_step + step, @@ -851,6 +856,8 @@ namespace ctranslate2 { const size_t batch_id = batch_offset[i]; const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0; const float score = best_probs.scalar_at({i, 0}); + // convert word_id from + const float log_prob = log_probs.scalar_at({i, static_cast(word_id)}); if ((!is_eos(word_id, end_ids) || include_eos_in_hypotheses) && (return_prefix || step >= prefix_length)) { @@ -862,7 +869,10 @@ namespace ctranslate2 { } if (return_scores) + { results[batch_id].scores[0] += score; + results[batch_id].logits[step] = log_prob; + }; bool is_finished = ((is_eos(word_id, end_ids) && step >= prefix_length) || (is_last_step(step, max_length, prefix_length, return_prefix))); diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index ed4bb214b..346ec43b6 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -423,7 +423,8 @@ namespace ctranslate2 { final_results.emplace_back(std::move(hypotheses), std::move(result.scores), - std::move(result.attention)); + std::move(result.attention), + std::move(result.logits)); } return final_results; @@ -462,6 +463,8 @@ namespace ctranslate2 { result.scores.emplace_back(0); if (options.return_attention) result.attention.emplace_back(attention); + if (options.return_scores) + result.logits.emplace_back(0); } return true;