Skip to content

Commit

Permalink
Fix conversion of OpenNMT-tf V1 checkpoints with the new converter (O…
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored Jun 20, 2022
1 parent b5cc6c1 commit b1e6601
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 3 additions & 2 deletions python/ctranslate2/converters/opennmt_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,14 @@ def from_config(
if auto_config:
config_util.merge_config(config, model.auto_config())

data_config = config_util.try_prefix_paths(config["model_dir"], config["data"])
model.initialize(data_config)

checkpoint = Checkpoint.from_config(config, model)
checkpoint_path = checkpoint.restore(checkpoint_path=checkpoint_path)
if checkpoint_path is None:
raise RuntimeError("No checkpoint was restored")

data_config = config_util.try_prefix_paths(config["model_dir"], config["data"])
model.initialize(data_config)
model.create_variables()
return cls(model)

Expand Down
6 changes: 6 additions & 0 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,12 @@ def test_model_unload_while_async_translation():
"en.vocab",
ctranslate2.specs.TransformerSpec(6, 8),
),
(
"v1/checkpoint",
"ar.vocab",
"en.vocab",
None,
),
(
"v2/checkpoint",
"ar.vocab",
Expand Down

0 comments on commit b1e6601

Please sign in to comment.