From 3a4047a883d649e15001c674ed23fae624a1ec55 Mon Sep 17 00:00:00 2001 From: Enkidu93 Date: Tue, 14 Jan 2025 09:05:14 -0500 Subject: [PATCH] Use parallel data when inferencing for word alignment --- .../Services/PreprocessBuildJob.cs | 4 +- .../WordAlignmentPreprocessBuildJob.cs | 4 +- .../test/Serval.E2ETests/ServalApiTests.cs | 14 ++- .../IParallelCorpusPreprocessingService.cs | 2 +- .../ParallelCorpusPreprocessingService.cs | 87 +++++++++++++++++-- .../ParallelCorpusProcessingServiceTests.cs | 8 +- 6 files changed, 103 insertions(+), 16 deletions(-) diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index 7fdd646e..bdee4885 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -135,9 +135,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync( if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) trainCount++; }, - async (row, corpus) => + async (row, isInTrainingData, corpus) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) + if (row.SourceSegment.Length > 0 && !isInTrainingData) { pretranslateWriter.WriteStartObject(); pretranslateWriter.WriteString("corpusId", corpus.Id); diff --git a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs index 29777dfd..8fd4e950 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/WordAlignmentPreprocessBuildJob.cs @@ -55,9 +55,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync( trainCount++; } }, - async (row, corpus) => + async (row, isInTrainingData, corpus) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0) + if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0 && !isInTrainingData) { inferenceWriter.WriteStartObject(); inferenceWriter.WriteString("corpusId", corpus.Id); diff --git a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs index 8b280518..89819615 100644 --- a/src/Serval/test/Serval.E2ETests/ServalApiTests.cs +++ b/src/Serval/test/Serval.E2ETests/ServalApiTests.cs @@ -480,9 +480,19 @@ public async Task ParatextProjectNmtJobAsync() public async Task GetWordAlignment() { string engineId = await _helperClient.CreateNewEngineAsync("Statistical", "es", "en", "STAT1"); - string[] books = ["1JN.txt", "2JN.txt", "3JN.txt", "MAT.txt"]; + string[] books = ["1JN.txt", "2JN.txt", "MAT.txt"]; ParallelCorpusConfig train_corpus = await _helperClient.MakeParallelTextCorpus(books, "es", "en", false); - await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false); + ParallelCorpusConfig test_corpus = await _helperClient.MakeParallelTextCorpus(["3JN.txt"], "es", "en", false); + string train_corpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false); + string corpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, test_corpus, true); + _helperClient.WordAlignmentBuildConfig.TrainOn = + [ + new TrainingCorpusConfig2() { ParallelCorpusId = train_corpusId } + ]; + _helperClient.WordAlignmentBuildConfig.WordAlignOn = + [ + new WordAlignmentCorpusConfig() { ParallelCorpusId = corpusId } + ]; await _helperClient.BuildEngineAsync(engineId); WordAlignmentResult tResult = await _helperClient.WordAlignmentEnginesClient.GetWordAlignmentAsync( engineId, diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs index 5e5fa959..32912734 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/IParallelCorpusPreprocessingService.cs @@ -5,7 +5,7 @@ public interface IParallelCorpusPreprocessingService Task PreprocessAsync( IReadOnlyList corpora, Func train, - Func pretranslate, + Func pretranslate, bool useKeyTerms = false ); } diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs index b0cb42b6..786de272 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Services/ParallelCorpusPreprocessingService.cs @@ -28,7 +28,7 @@ internal int Seed public async Task PreprocessAsync( IReadOnlyList corpora, Func train, - Func pretranslate, + Func inference, bool useKeyTerms = false ) { @@ -57,6 +57,11 @@ public async Task PreprocessAsync( .Select(tc => FilterTrainingCorpora(tc.Corpus, tc.TextCorpus)) .ToArray(); + ITextCorpus targetPretranslateCorpus = targetCorpora + .Select(tc => FilterPretranslateCorpora(tc.Corpus, tc.TextCorpus)) + .ToArray() + .ChooseRandom(Seed); + ITextCorpus sourceTrainingCorpus = sourceTrainingCorpora.ChooseRandom(Seed); if (sourceTrainingCorpus.IsScripture()) { @@ -113,14 +118,16 @@ ParallelTextRow row in parallelKeyTermsCorpus.DistinctBy(row => } ITextCorpus sourcePretranslateCorpus = sourcePretranslateCorpora.ChooseFirst(); - IParallelTextCorpus pretranslateCorpus = sourcePretranslateCorpus.AlignRows( - targetCorpus, - allSourceRows: true - ); + INParallelTextCorpus pretranslateCorpus = new ITextCorpus[] + { + sourcePretranslateCorpus, + targetPretranslateCorpus, + targetCorpus + }.AlignMany([true, false, false]); - foreach (Row row in CollapseRanges(pretranslateCorpus.ToArray())) + foreach ((Row row, bool isInTrainingData) in CollapsePretranslateRanges(pretranslateCorpus.ToArray())) { - await pretranslate(row, corpus); + await inference(row, isInTrainingData, corpus); } } } @@ -229,6 +236,72 @@ private static IEnumerable CollapseRanges(ParallelTextRow[] rows) } } + private static IEnumerable<(Row, bool)> CollapsePretranslateRanges(NParallelTextRow[] rows) + { + StringBuilder srcSegBuffer = new(); + StringBuilder trgSegBuffer = new(); + List refs = []; + string textId = ""; + bool hasUnfinishedRange = false; + bool isInTrainingData = false; + + foreach (NParallelTextRow row in rows) + { + if ( + hasUnfinishedRange + && (!row.IsInRange(0) || row.IsRangeStart(0)) + && (!row.IsInRange(1) || row.IsRangeStart(1)) + && (!row.IsInRange(2) || row.IsRangeStart(2)) + ) + { + yield return ( + new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), + isInTrainingData + ); + + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + isInTrainingData = false; + hasUnfinishedRange = false; + } + + textId = row.TextId; + refs.AddRange(row.NRefs[2]); + isInTrainingData = isInTrainingData || row.Text(2).Length > 0; + + if (row.Text(0).Length > 0) + { + if (srcSegBuffer.Length > 0) + srcSegBuffer.Append(' '); + srcSegBuffer.Append(row.Text(0)); + } + if (row.Text(1).Length > 0) + { + if (trgSegBuffer.Length > 0) + trgSegBuffer.Append(' '); + trgSegBuffer.Append(row.Text(1)); + } + + if (row.IsInRange(0) || row.IsInRange(1) || row.IsInRange(2)) + { + hasUnfinishedRange = true; + continue; + } + + yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData); + + srcSegBuffer.Clear(); + trgSegBuffer.Clear(); + refs.Clear(); + isInTrainingData = false; + } + if (hasUnfinishedRange) + { + yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData); + } + } + private static bool IsScriptureRow(TextRow parallelTextRow) { return parallelTextRow.Ref is ScriptureRef sr && sr.IsVerse; diff --git a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs index cdd1884f..4c81bc92 100644 --- a/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs +++ b/src/ServiceToolkit/test/SIL.ServiceToolkit.Tests/Services/ParallelCorpusProcessingServiceTests.cs @@ -81,14 +81,18 @@ await processor.PreprocessAsync( trainCount++; return Task.CompletedTask; }, - (row, _) => + (row, isInTrainingData, _) => { - if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0) + if (row.SourceSegment.Length > 0 && !isInTrainingData) + { pretranslateCount++; + } + return Task.CompletedTask; }, false ); + Assert.Multiple(() => { Assert.That(trainCount, Is.EqualTo(2));