diff --git a/src/SIL.Machine.AspNetCore/Models/Corpus.cs b/src/SIL.Machine.AspNetCore/Models/Corpus.cs index ae8b76240..a6847555d 100644 --- a/src/SIL.Machine.AspNetCore/Models/Corpus.cs +++ b/src/SIL.Machine.AspNetCore/Models/Corpus.cs @@ -5,12 +5,10 @@ public record Corpus public required string Id { get; init; } public required string SourceLanguage { get; init; } public required string TargetLanguage { get; init; } - public required bool TrainOnAll { get; init; } - public required bool PretranslateAll { get; init; } public IReadOnlyDictionary>? TrainOnChapters { get; init; } public IReadOnlyDictionary>? PretranslateChapters { get; init; } - public required HashSet TrainOnTextIds { get; init; } - public required HashSet PretranslateTextIds { get; init; } + public required HashSet? TrainOnTextIds { get; init; } + public required HashSet? PretranslateTextIds { get; init; } public required IReadOnlyList SourceFiles { get; init; } public required IReadOnlyList TargetFiles { get; init; } } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs index dc00851eb..a7f86e20d 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -149,7 +149,7 @@ CancellationToken cancellationToken continue; } - Row[] trainRows = rows.Where(r => IsIncluded(r, corpus.TrainOnChapters)).Cast().ToArray(); + Row[] trainRows = rows.Where(row => IsInTrain(row, corpus)).Cast().ToArray(); if (trainRows.Length > 0) { Row row = trainRows[0]; @@ -187,7 +187,7 @@ CancellationToken cancellationToken foreach (Row row in AlignPretranslateCorpus(corpus, sourceTextCorpora[0], targetTextCorpus)) { if ( - IsIncluded(row, corpus.PretranslateChapters) + IsInPretranslate(row, corpus) && row.SourceSegment.Length > 0 && (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus)) ) @@ -231,22 +231,32 @@ JobCompletionStatus completionStatus } } - private static bool IsInTrain(Row row, Corpus corpus) + private static bool IsInTrain(Row? row, Corpus corpus) { - if (corpus.TrainOnChapters is not null) - { - if (row.Refs.Any(r => IsInChapters(corpus.TrainOnChapters, r))) - return true; - } - return corpus.TrainOnAll || corpus.TrainOnTextIds.Contains(row.TextId); + return IsIncluded(row, corpus.TrainOnTextIds, corpus.TrainOnChapters); } - private static bool IsIncluded(Row? row, IReadOnlyDictionary>? chapters) + private static bool IsInPretranslate(Row? row, Corpus corpus) + { + return IsIncluded(row, corpus.PretranslateTextIds, corpus.PretranslateChapters); + } + + private static bool IsIncluded( + Row? row, + IReadOnlySet? textIds, + IReadOnlyDictionary>? chapters + ) { if (row is null) return false; if (chapters is not null) + { return row.Refs.Any(r => IsInChapters(chapters, r)); + } + if (textIds is not null) + { + return textIds.Contains(row.TextId); + } return true; } @@ -264,14 +274,11 @@ private static bool IsInChapters(IReadOnlyDictionary> bookC ITextCorpus trgCorpus ) { - if (!corpus.TrainOnAll) - { - IEnumerable textIds = corpus.TrainOnChapters is not null - ? corpus.TrainOnChapters.Keys - : corpus.TrainOnTextIds; - srcCorpora = srcCorpora.Select(sc => sc.FilterTexts(textIds)).ToArray(); - trgCorpus = trgCorpus.FilterTexts(textIds); - } + IEnumerable? textIds = corpus.TrainOnChapters is not null + ? corpus.TrainOnChapters.Keys + : corpus.TrainOnTextIds; + srcCorpora = srcCorpora.Select(sc => sc.FilterTexts(textIds)).ToArray(); + trgCorpus = trgCorpus.FilterTexts(textIds); if (trgCorpus.IsScripture()) { @@ -388,14 +395,11 @@ ITextCorpus trgCorpus private static IEnumerable AlignPretranslateCorpus(Corpus corpus, ITextCorpus srcCorpus, ITextCorpus trgCorpus) { - if (!corpus.PretranslateAll) - { - IEnumerable textIds = corpus.PretranslateChapters is not null - ? corpus.PretranslateChapters.Keys - : corpus.PretranslateTextIds; - srcCorpus = srcCorpus.FilterTexts(textIds); - trgCorpus = trgCorpus.FilterTexts(textIds); - } + IEnumerable? textIds = corpus.PretranslateChapters is not null + ? corpus.PretranslateChapters.Keys + : corpus.PretranslateTextIds; + srcCorpus = srcCorpus.FilterTexts(textIds); + trgCorpus = trgCorpus.FilterTexts(textIds); int rowCount = 0; StringBuilder srcSegBuffer = new(); diff --git a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs index fe56df75c..ea712e637 100644 --- a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs +++ b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs @@ -285,23 +285,28 @@ private static Serval.Translation.V1.Phrase Map(Translation.Phrase source) private static Models.Corpus Map(Serval.Translation.V1.Corpus source) { + var pretranslateChapters = source.PretranslateChapters.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value.Chapters.ToHashSet() + ); + FilterChoice pretranslateFilter = GetFilterChoice(source.PretranslateAll, pretranslateChapters); + + var trainOnChapters = source.TrainOnChapters.ToDictionary( + kvp => kvp.Key, + kvp => kvp.Value.Chapters.ToHashSet() + ); + FilterChoice trainingFilter = GetFilterChoice(source.TrainOnAll, trainOnChapters); + return new Models.Corpus { Id = source.Id, SourceLanguage = source.SourceLanguage, TargetLanguage = source.TargetLanguage, - TrainOnAll = source.TrainOnAll, - PretranslateAll = source.PretranslateAll, - TrainOnChapters = source.TrainOnChapters.ToDictionary( - kvp => kvp.Key, - kvp => kvp.Value.Chapters.ToHashSet() - ), - PretranslateChapters = source.PretranslateChapters.ToDictionary( - kvp => kvp.Key, - kvp => kvp.Value.Chapters.ToHashSet() - ), - TrainOnTextIds = source.TrainOnTextIds.ToHashSet(), - PretranslateTextIds = source.PretranslateTextIds.ToHashSet(), + TrainOnChapters = trainingFilter == FilterChoice.Chapters ? trainOnChapters : null, + PretranslateChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null, + TrainOnTextIds = trainingFilter == FilterChoice.TextIds ? source.TrainOnTextIds.ToHashSet() : null, + PretranslateTextIds = + pretranslateFilter == FilterChoice.TextIds ? source.PretranslateTextIds.ToHashSet() : null, SourceFiles = source.SourceFiles.Select(Map).ToList(), TargetFiles = source.TargetFiles.Select(Map).ToList() }; @@ -316,4 +321,24 @@ private static Models.CorpusFile Map(Serval.Translation.V1.CorpusFile source) TextId = source.TextId }; } + + private enum FilterChoice + { + Chapters, + TextIds, + None + } + + private static FilterChoice GetFilterChoice(bool all, IReadOnlyDictionary> chapters) + { + if (all) + return FilterChoice.None; + + // Only either textIds or Scripture Range will be used at a time + // TextIds may be an empty array, so prefer that if both are empty (which applies to both scripture and text) + if (chapters.Count == 0) + return FilterChoice.TextIds; + else + return FilterChoice.Chapters; + } } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs index b3ec04de0..cf1b8002d 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtPreprocessBuildJobTests.cs @@ -25,7 +25,7 @@ public async Task RunAsync_FilterOutEverything() public async Task RunAsync_TrainOnAll() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnAll = true }; + Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnTextIds = null }; await env.RunBuildJobAsync(corpus1); @@ -61,7 +61,7 @@ public async Task RunAsync_TrainOnTextIds() public async Task RunAsync_TrainAndPretranslateAll() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateAll = true, TrainOnAll = true }; + Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = null, TrainOnTextIds = null }; await env.RunBuildJobAsync(corpus1); @@ -72,7 +72,7 @@ public async Task RunAsync_TrainAndPretranslateAll() public async Task RunAsync_PretranslateAll() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateAll = true }; + Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = null }; await env.RunBuildJobAsync(corpus1); @@ -83,7 +83,7 @@ public async Task RunAsync_PretranslateAll() public async Task RunAsync_PretranslateTextIds() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = ["textId1"], TrainOnAll = true }; + Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = ["textId1"], TrainOnTextIds = null }; await env.RunBuildJobAsync(corpus1); @@ -177,7 +177,11 @@ public async Task RunAsync_TrainOnChapters() public async Task RunAsync_MixedSource_Paratext() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultMixedSourceParatextCorpus with { TrainOnAll = true, PretranslateAll = true }; + Corpus corpus1 = env.DefaultMixedSourceParatextCorpus with + { + TrainOnTextIds = null, + PretranslateTextIds = null + }; await env.RunBuildJobAsync(corpus1, useKeyTerms: false); @@ -196,7 +200,13 @@ public async Task RunAsync_MixedSource_Paratext() public async Task RunAsync_MixedSource_Text() { using TestEnvironment env = new(); - Corpus corpus1 = env.DefaultMixedSourceTextFileCorpus with { TrainOnAll = true, PretranslateAll = true }; + Corpus corpus1 = env.DefaultMixedSourceTextFileCorpus with + { + TrainOnTextIds = null, + PretranslateTextIds = null, + TrainOnChapters = null, + PretranslateChapters = null + }; await env.RunBuildJobAsync(corpus1); @@ -268,8 +278,6 @@ public TestEnvironment() Id = "corpusId1", SourceLanguage = "es", TargetLanguage = "en", - PretranslateAll = false, - TrainOnAll = false, PretranslateTextIds = [], TrainOnTextIds = [], SourceFiles = [TextFile("source1")], @@ -281,8 +289,6 @@ public TestEnvironment() Id = "corpusId1", SourceLanguage = "es", TargetLanguage = "en", - PretranslateAll = false, - TrainOnAll = false, PretranslateTextIds = [], TrainOnTextIds = [], SourceFiles = [TextFile("source1"), TextFile("source2")], @@ -294,8 +300,6 @@ public TestEnvironment() Id = "corpusId1", SourceLanguage = "es", TargetLanguage = "en", - PretranslateAll = false, - TrainOnAll = false, PretranslateTextIds = [], TrainOnTextIds = [], SourceFiles = [ParatextFile("pt-source1")], @@ -307,8 +311,6 @@ public TestEnvironment() Id = "corpusId1", SourceLanguage = "es", TargetLanguage = "en", - PretranslateAll = false, - TrainOnAll = false, PretranslateTextIds = [], TrainOnTextIds = [], SourceFiles = [ParatextFile("pt-source1"), ParatextFile("pt-source2")], diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index c3bb9b9c3..c33974bf9 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -41,10 +41,8 @@ await env.Service.StartBuildAsync( TargetLanguage = "en", SourceFiles = [], TargetFiles = [], - TrainOnAll = true, - PretranslateAll = true, - TrainOnTextIds = [], - PretranslateTextIds = [] + TrainOnTextIds = null, + PretranslateTextIds = null } } );