Skip to content

Commit

Permalink
Move embeddings scale initialization in a single function (OpenNMT#834)
Browse files: browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Jun 17, 2022
1 parent 08eed4a commit 2c8db5e
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 10 deletions.
1 change: 0 additions & 1 deletion include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ namespace ctranslate2 {
const StorageView& _embeddings;
const DataType _output_type;
const StorageView* _qscale;
- std::unique_ptr<const StorageView> _scale;
};

// This enum order should remain fixed.
Expand Down
7 changes: 0 additions & 7 deletions src/layers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ namespace ctranslate2 {
, _output_type(get_default_float_type(model.effective_compute_type()))
, _qscale(model.get_variable_if_exists(scope + "/weight_scale"))
{
- if (model.get_flag_with_default(scope + "/multiply_by_sqrt_depth", true)) {
- const StorageView scale(std::sqrt(static_cast<float>(_embeddings.dim(1))));
- _scale = std::make_unique<StorageView>(scale.to(_output_type));
- }
}

DataType Embeddings::output_type() const {
Expand Down Expand Up @@ -81,9 +77,6 @@ namespace ctranslate2 {
} else {
_gather_op(_embeddings, ids, output);
}
-
- if (_scale)
- ops::Mul()(output, *_scale, output);
}


Expand Down
6 changes: 4 additions & 2 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,15 @@ namespace ctranslate2 {
const std::string& scope,
const Layer& embeddings) {
const auto* scale = model.get_variable_if_exists(scope + "/scale_embeddings");
+
+ // Backward compatibility with older models.
if (!scale)
- return nullptr;
+ scale = model.get_variable_if_exists(scope + "/embeddings/multiply_by_sqrt_depth");

StorageView value;

// The attribute can either be a boolean flag or the actual scale value.
- if (scale->dtype() == DataType::INT8 && scale->as_scalar<int8_t>())
+ if (!scale || (scale->dtype() == DataType::INT8 && scale->as_scalar<int8_t>()))
value = StorageView(std::sqrt(static_cast<float>(embeddings.output_size())));
else if (scale->dtype() != DataType::INT8 && scale->as_scalar<float>() != 1.f)
value = *scale;
Expand Down

0 comments on commit 2c8db5e

Please sign in to comment.