Skip to content

Commit

Permalink
Use parallel data when inferencing for word alignment
Browse files Browse the repository at this point in the history
Enkidu93 committed Jan 14, 2025
1 parent 97b599f commit 3a4047a
Showing 6 changed files with 103 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -135,9 +135,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync(
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
},
async (row, corpus) =>
async (row, isInTrainingData, corpus) =>
{
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0)
if (row.SourceSegment.Length > 0 && !isInTrainingData)
{
pretranslateWriter.WriteStartObject();
pretranslateWriter.WriteString("corpusId", corpus.Id);
Original file line number Diff line number Diff line change
@@ -55,9 +55,9 @@ await ParallelCorpusPreprocessingService.PreprocessAsync(
trainCount++;
}
},
async (row, corpus) =>
async (row, isInTrainingData, corpus) =>
{
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0 && !isInTrainingData)
{
inferenceWriter.WriteStartObject();
inferenceWriter.WriteString("corpusId", corpus.Id);
14 changes: 12 additions & 2 deletions src/Serval/test/Serval.E2ETests/ServalApiTests.cs
Original file line number Diff line number Diff line change
@@ -480,9 +480,19 @@ public async Task ParatextProjectNmtJobAsync()
public async Task GetWordAlignment()
{
string engineId = await _helperClient.CreateNewEngineAsync("Statistical", "es", "en", "STAT1");
string[] books = ["1JN.txt", "2JN.txt", "3JN.txt", "MAT.txt"];
string[] books = ["1JN.txt", "2JN.txt", "MAT.txt"];
ParallelCorpusConfig train_corpus = await _helperClient.MakeParallelTextCorpus(books, "es", "en", false);
await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false);
ParallelCorpusConfig test_corpus = await _helperClient.MakeParallelTextCorpus(["3JN.txt"], "es", "en", false);
string train_corpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, train_corpus, false);
string corpusId = await _helperClient.AddParallelTextCorpusToEngineAsync(engineId, test_corpus, true);
_helperClient.WordAlignmentBuildConfig.TrainOn =
[
new TrainingCorpusConfig2() { ParallelCorpusId = train_corpusId }
];
_helperClient.WordAlignmentBuildConfig.WordAlignOn =
[
new WordAlignmentCorpusConfig() { ParallelCorpusId = corpusId }
];
await _helperClient.BuildEngineAsync(engineId);
WordAlignmentResult tResult = await _helperClient.WordAlignmentEnginesClient.GetWordAlignmentAsync(
engineId,
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ public interface IParallelCorpusPreprocessingService
Task PreprocessAsync(
IReadOnlyList<ParallelCorpus> corpora,
Func<Row, Task> train,
Func<Row, ParallelCorpus, Task> pretranslate,
Func<Row, bool, ParallelCorpus, Task> pretranslate,
bool useKeyTerms = false
);
}
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ internal int Seed
public async Task PreprocessAsync(
IReadOnlyList<ParallelCorpus> corpora,
Func<Row, Task> train,
Func<Row, ParallelCorpus, Task> pretranslate,
Func<Row, bool, ParallelCorpus, Task> inference,
bool useKeyTerms = false
)
{
@@ -57,6 +57,11 @@ public async Task PreprocessAsync(
.Select(tc => FilterTrainingCorpora(tc.Corpus, tc.TextCorpus))
.ToArray();

ITextCorpus targetPretranslateCorpus = targetCorpora
.Select(tc => FilterPretranslateCorpora(tc.Corpus, tc.TextCorpus))
.ToArray()
.ChooseRandom(Seed);

ITextCorpus sourceTrainingCorpus = sourceTrainingCorpora.ChooseRandom(Seed);
if (sourceTrainingCorpus.IsScripture())
{
@@ -113,14 +118,16 @@ ParallelTextRow row in parallelKeyTermsCorpus.DistinctBy(row =>
}
ITextCorpus sourcePretranslateCorpus = sourcePretranslateCorpora.ChooseFirst();

IParallelTextCorpus pretranslateCorpus = sourcePretranslateCorpus.AlignRows(
targetCorpus,
allSourceRows: true
);
INParallelTextCorpus pretranslateCorpus = new ITextCorpus[]
{
sourcePretranslateCorpus,
targetPretranslateCorpus,
targetCorpus
}.AlignMany([true, false, false]);

foreach (Row row in CollapseRanges(pretranslateCorpus.ToArray()))
foreach ((Row row, bool isInTrainingData) in CollapsePretranslateRanges(pretranslateCorpus.ToArray()))
{
await pretranslate(row, corpus);
await inference(row, isInTrainingData, corpus);
}
}
}
@@ -229,6 +236,72 @@ private static IEnumerable<Row> CollapseRanges(ParallelTextRow[] rows)
}
}

private static IEnumerable<(Row, bool)> CollapsePretranslateRanges(NParallelTextRow[] rows)
{
StringBuilder srcSegBuffer = new();
StringBuilder trgSegBuffer = new();
List<object> refs = [];
string textId = "";
bool hasUnfinishedRange = false;
bool isInTrainingData = false;

foreach (NParallelTextRow row in rows)
{
if (
hasUnfinishedRange
&& (!row.IsInRange(0) || row.IsRangeStart(0))
&& (!row.IsInRange(1) || row.IsRangeStart(1))
&& (!row.IsInRange(2) || row.IsRangeStart(2))
)
{
yield return (
new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1),
isInTrainingData
);

srcSegBuffer.Clear();
trgSegBuffer.Clear();
refs.Clear();
isInTrainingData = false;
hasUnfinishedRange = false;
}

textId = row.TextId;
refs.AddRange(row.NRefs[2]);
isInTrainingData = isInTrainingData || row.Text(2).Length > 0;

if (row.Text(0).Length > 0)
{
if (srcSegBuffer.Length > 0)
srcSegBuffer.Append(' ');
srcSegBuffer.Append(row.Text(0));
}
if (row.Text(1).Length > 0)
{
if (trgSegBuffer.Length > 0)
trgSegBuffer.Append(' ');
trgSegBuffer.Append(row.Text(1));
}

if (row.IsInRange(0) || row.IsInRange(1) || row.IsInRange(2))
{
hasUnfinishedRange = true;
continue;
}

yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData);

srcSegBuffer.Clear();
trgSegBuffer.Clear();
refs.Clear();
isInTrainingData = false;
}
if (hasUnfinishedRange)
{
yield return (new Row(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1), isInTrainingData);
}
}

private static bool IsScriptureRow(TextRow parallelTextRow)
{
return parallelTextRow.Ref is ScriptureRef sr && sr.IsVerse;
Original file line number Diff line number Diff line change
@@ -81,14 +81,18 @@ await processor.PreprocessAsync(
trainCount++;
return Task.CompletedTask;
},
(row, _) =>
(row, isInTrainingData, _) =>
{
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length == 0)
if (row.SourceSegment.Length > 0 && !isInTrainingData)
{
pretranslateCount++;
}

return Task.CompletedTask;
},
false
);

Assert.Multiple(() =>
{
Assert.That(trainCount, Is.EqualTo(2));

0 comments on commit 3a4047a

Please sign in to comment.