From b1e6601fc9ab3a8597a747c9768b9cba637138b1 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Mon, 20 Jun 2022 11:09:25 +0200 Subject: [PATCH] Fix conversion of OpenNMT-tf V1 checkpoints with the new converter (#845) --- python/ctranslate2/converters/opennmt_tf.py | 5 +++-- python/tests/test.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/ctranslate2/converters/opennmt_tf.py b/python/ctranslate2/converters/opennmt_tf.py index e23532bb0..61eabc30c 100644 --- a/python/ctranslate2/converters/opennmt_tf.py +++ b/python/ctranslate2/converters/opennmt_tf.py @@ -59,13 +59,14 @@ def from_config( if auto_config: config_util.merge_config(config, model.auto_config()) + data_config = config_util.try_prefix_paths(config["model_dir"], config["data"]) + model.initialize(data_config) + checkpoint = Checkpoint.from_config(config, model) checkpoint_path = checkpoint.restore(checkpoint_path=checkpoint_path) if checkpoint_path is None: raise RuntimeError("No checkpoint was restored") - data_config = config_util.try_prefix_paths(config["model_dir"], config["data"]) - model.initialize(data_config) model.create_variables() return cls(model) diff --git a/python/tests/test.py b/python/tests/test.py index 9147d9a52..1ded586b5 100644 --- a/python/tests/test.py +++ b/python/tests/test.py @@ -538,6 +538,12 @@ def test_model_unload_while_async_translation(): "en.vocab", ctranslate2.specs.TransformerSpec(6, 8), ), + ( + "v1/checkpoint", + "ar.vocab", + "en.vocab", + None, + ), ( "v2/checkpoint", "ar.vocab",