diff --git a/src/SIL.Machine.Tool/CorpusCommandSpecBase.cs b/src/SIL.Machine.Tool/CorpusCommandSpecBase.cs index 2186239c1..be45ab6a5 100644 --- a/src/SIL.Machine.Tool/CorpusCommandSpecBase.cs +++ b/src/SIL.Machine.Tool/CorpusCommandSpecBase.cs @@ -1,6 +1,5 @@ using System.Collections.Generic; using System.IO; -using System.Linq; using McMaster.Extensions.CommandLineUtils; using SIL.Machine.Corpora; diff --git a/src/SIL.Machine.Tool/ParallelCorpusCommandSpec.cs b/src/SIL.Machine.Tool/ParallelCorpusCommandSpec.cs index 6a014303c..dd1d82e6f 100644 --- a/src/SIL.Machine.Tool/ParallelCorpusCommandSpec.cs +++ b/src/SIL.Machine.Tool/ParallelCorpusCommandSpec.cs @@ -1,5 +1,4 @@ -using System.Collections.Generic; -using System.IO; +using System.IO; using McMaster.Extensions.CommandLineUtils; using SIL.Machine.Corpora; using SIL.Machine.Tokenization; diff --git a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj index 5a69d6fb6..d326a66af 100644 --- a/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj +++ b/src/SIL.Machine.Translation.Thot/SIL.Machine.Translation.Thot.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/SIL.Machine.Translation.Thot/Thot.cs b/src/SIL.Machine.Translation.Thot/Thot.cs index 5f639f3fb..741fdf964 100644 --- a/src/SIL.Machine.Translation.Thot/Thot.cs +++ b/src/SIL.Machine.Translation.Thot/Thot.cs @@ -613,8 +613,10 @@ IReadOnlyList targetSegment public static IntPtr LoadSmtModel(ThotWordAlignmentModelType alignmentModelType, ThotSmtParameters parameters) { IntPtr handle = smtModel_create(GetAlignmentModelType(alignmentModelType, incremental: true)); - smtModel_loadTranslationModel(handle, parameters.TranslationModelFileNamePrefix); - smtModel_loadLanguageModel(handle, parameters.LanguageModelFileNamePrefix); + if (!smtModel_loadTranslationModel(handle, parameters.TranslationModelFileNamePrefix)) + throw new InvalidOperationException("Unable to load translation model."); + if (!smtModel_loadLanguageModel(handle, parameters.LanguageModelFileNamePrefix)) + throw new InvalidOperationException("Unable to load language model."); smtModel_setNonMonotonicity(handle, parameters.ModelNonMonotonicity); smtModel_setW(handle, parameters.ModelW); smtModel_setA(handle, parameters.ModelA); @@ -650,7 +652,10 @@ public static IntPtr CreateAlignmentModel(ThotWordAlignmentModelType type, IntPt public static IntPtr OpenAlignmentModel(ThotWordAlignmentModelType type, string prefFileName) { - return swAlignModel_open(GetAlignmentModelType(type, incremental: false), prefFileName); + IntPtr handle = swAlignModel_open(GetAlignmentModelType(type, incremental: false), prefFileName); + if (handle == IntPtr.Zero) + throw new InvalidOperationException("Unable to load word alignment model."); + return handle; } public static AlignmentModelType GetAlignmentModelType(ThotWordAlignmentModelType type, bool incremental) diff --git a/src/SIL.Machine.Translation.Thot/ThotSmtModel.cs b/src/SIL.Machine.Translation.Thot/ThotSmtModel.cs index 51ea9d357..88e790a0d 100644 --- a/src/SIL.Machine.Translation.Thot/ThotSmtModel.cs +++ b/src/SIL.Machine.Translation.Thot/ThotSmtModel.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -31,6 +32,12 @@ public ThotSmtModel(ThotWordAlignmentModelType wordAlignmentModelType, string cf public ThotSmtModel(ThotWordAlignmentModelType wordAlignmentModelType, ThotSmtParameters parameters) { + if (!File.Exists(parameters.TranslationModelFileNamePrefix + ".ttable")) + throw new FileNotFoundException("The translation model could not be found."); + + if (!File.Exists(parameters.LanguageModelFileNamePrefix)) + throw new FileNotFoundException("The language model could not be found."); + _decoderPool = new ObjectPool(MaxDecoderPoolSize, () => new ThotSmtDecoder(this)); Parameters = parameters; @@ -605,7 +612,8 @@ protected override void DisposeManagedResources() protected override void DisposeUnmanagedResources() { - Thot.smtModel_close(_handle); + if (_handle != IntPtr.Zero) + Thot.smtModel_close(_handle); } private double GetWordAlignmentScore( diff --git a/tests/SIL.Machine.Translation.Thot.Tests/ThotFastAlignWordAlignmentModelTests.cs b/tests/SIL.Machine.Translation.Thot.Tests/ThotFastAlignWordAlignmentModelTests.cs index 9b819eb66..0286494cb 100644 --- a/tests/SIL.Machine.Translation.Thot.Tests/ThotFastAlignWordAlignmentModelTests.cs +++ b/tests/SIL.Machine.Translation.Thot.Tests/ThotFastAlignWordAlignmentModelTests.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using SIL.Machine.Utils; namespace SIL.Machine.Translation.Thot { @@ -165,5 +166,14 @@ public void GetAvgTranslationScore_Symmetrized() double score = model.GetAvgTranslationScore(sourceSegment, targetSegment, alignment); Assert.That(score, Is.EqualTo(0.36).Within(0.01)); } + + [Test] + public void Constructor_ModelCorrupted() + { + using var tempDir = new TempDirectory("ThotFastAlignWordAlignmentModelTests"); + string modelPrefix = Path.Combine(tempDir.Path, "src_trg_invswm"); + File.WriteAllText(modelPrefix + ".src", "corrupted"); + Assert.Throws(() => new ThotFastAlignWordAlignmentModel(modelPrefix)); + } } } diff --git a/tests/SIL.Machine.Translation.Thot.Tests/ThotSmtModelTests.cs b/tests/SIL.Machine.Translation.Thot.Tests/ThotSmtModelTests.cs index ef9d135cf..9609840df 100644 --- a/tests/SIL.Machine.Translation.Thot.Tests/ThotSmtModelTests.cs +++ b/tests/SIL.Machine.Translation.Thot.Tests/ThotSmtModelTests.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using SIL.Machine.Utils; namespace SIL.Machine.Translation.Thot { @@ -189,6 +190,45 @@ public async Task GetWordGraphAsync_EmptySegment_FastAlign() Assert.That(wordGraph.IsEmpty, Is.True); } + [Test] + public void Constructor_ModelDoesNotExist() + { + Assert.Throws( + () => + new ThotSmtModel( + ThotWordAlignmentModelType.Hmm, + new ThotSmtParameters + { + TranslationModelFileNamePrefix = "does-not-exist", + LanguageModelFileNamePrefix = "does-not-exist" + } + ) + ); + } + + [Test] + public void Constructor_ModelCorrupted() + { + using var tempDir = new TempDirectory("ThotSmtModelTests"); + string tmDir = Path.Combine(tempDir.Path, "tm"); + Directory.CreateDirectory(tmDir); + File.WriteAllText(Path.Combine(tmDir, "src_trg.ttable"), "corrupted"); + string lmDir = Path.Combine(tempDir.Path, "lm"); + Directory.CreateDirectory(lmDir); + File.WriteAllText(Path.Combine(lmDir, "trg.lm"), "corrupted"); + Assert.Throws( + () => + new ThotSmtModel( + ThotWordAlignmentModelType.Hmm, + new ThotSmtParameters + { + TranslationModelFileNamePrefix = Path.Combine(tmDir, "src_trg"), + LanguageModelFileNamePrefix = Path.Combine(lmDir, "trg.lm") + } + ) + ); + } + private static ThotSmtModel CreateHmmModel() { return new ThotSmtModel(ThotWordAlignmentModelType.Hmm, TestHelpers.ToyCorpusHmmConfigFileName);