Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clarify and fix logic for textId and Chapters. #209

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/SIL.Machine.AspNetCore/Models/Corpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ public record Corpus
public required string Id { get; init; }
public required string SourceLanguage { get; init; }
public required string TargetLanguage { get; init; }
public required bool TrainOnAll { get; init; }
public required bool PretranslateAll { get; init; }
public IReadOnlyDictionary<string, HashSet<int>>? TrainOnChapters { get; init; }
public IReadOnlyDictionary<string, HashSet<int>>? PretranslateChapters { get; init; }
public required HashSet<string> TrainOnTextIds { get; init; }
public required HashSet<string> PretranslateTextIds { get; init; }
public required HashSet<string>? TrainOnTextIds { get; init; }
public required HashSet<string>? PretranslateTextIds { get; init; }
public required IReadOnlyList<CorpusFile> SourceFiles { get; init; }
public required IReadOnlyList<CorpusFile> TargetFiles { get; init; }
}
56 changes: 30 additions & 26 deletions src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ CancellationToken cancellationToken
continue;
}

Row[] trainRows = rows.Where(r => IsIncluded(r, corpus.TrainOnChapters)).Cast<Row>().ToArray();
Row[] trainRows = rows.Where(row => IsInTrain(row, corpus)).Cast<Row>().ToArray();
if (trainRows.Length > 0)
{
Row row = trainRows[0];
Expand Down Expand Up @@ -187,7 +187,7 @@ CancellationToken cancellationToken
foreach (Row row in AlignPretranslateCorpus(corpus, sourceTextCorpora[0], targetTextCorpus))
{
if (
IsIncluded(row, corpus.PretranslateChapters)
IsInPretranslate(row, corpus)
&& row.SourceSegment.Length > 0
&& (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus))
)
Expand Down Expand Up @@ -231,22 +231,32 @@ JobCompletionStatus completionStatus
}
}

private static bool IsInTrain(Row row, Corpus corpus)
private static bool IsInTrain(Row? row, Corpus corpus)
{
if (corpus.TrainOnChapters is not null)
{
if (row.Refs.Any(r => IsInChapters(corpus.TrainOnChapters, r)))
return true;
}
return corpus.TrainOnAll || corpus.TrainOnTextIds.Contains(row.TextId);
return IsIncluded(row, corpus.TrainOnTextIds, corpus.TrainOnChapters);
}

private static bool IsIncluded(Row? row, IReadOnlyDictionary<string, HashSet<int>>? chapters)
private static bool IsInPretranslate(Row? row, Corpus corpus)
{
return IsIncluded(row, corpus.PretranslateTextIds, corpus.PretranslateChapters);
}

private static bool IsIncluded(
Row? row,
IReadOnlySet<string>? textIds,
IReadOnlyDictionary<string, HashSet<int>>? chapters
)
{
if (row is null)
return false;
if (chapters is not null)
{
return row.Refs.Any(r => IsInChapters(chapters, r));
}
if (textIds is not null)
{
return textIds.Contains(row.TextId);
}
return true;
}

Expand All @@ -264,14 +274,11 @@ private static bool IsInChapters(IReadOnlyDictionary<string, HashSet<int>> bookC
ITextCorpus trgCorpus
)
{
if (!corpus.TrainOnAll)
{
IEnumerable<string> textIds = corpus.TrainOnChapters is not null
? corpus.TrainOnChapters.Keys
: corpus.TrainOnTextIds;
srcCorpora = srcCorpora.Select(sc => sc.FilterTexts(textIds)).ToArray();
trgCorpus = trgCorpus.FilterTexts(textIds);
}
IEnumerable<string>? textIds = corpus.TrainOnChapters is not null
? corpus.TrainOnChapters.Keys
: corpus.TrainOnTextIds;
srcCorpora = srcCorpora.Select(sc => sc.FilterTexts(textIds)).ToArray();
trgCorpus = trgCorpus.FilterTexts(textIds);

if (trgCorpus.IsScripture())
{
Expand Down Expand Up @@ -388,14 +395,11 @@ ITextCorpus trgCorpus

private static IEnumerable<Row> AlignPretranslateCorpus(Corpus corpus, ITextCorpus srcCorpus, ITextCorpus trgCorpus)
{
if (!corpus.PretranslateAll)
{
IEnumerable<string> textIds = corpus.PretranslateChapters is not null
? corpus.PretranslateChapters.Keys
: corpus.PretranslateTextIds;
srcCorpus = srcCorpus.FilterTexts(textIds);
trgCorpus = trgCorpus.FilterTexts(textIds);
}
IEnumerable<string>? textIds = corpus.PretranslateChapters is not null
? corpus.PretranslateChapters.Keys
: corpus.PretranslateTextIds;
srcCorpus = srcCorpus.FilterTexts(textIds);
trgCorpus = trgCorpus.FilterTexts(textIds);

int rowCount = 0;
StringBuilder srcSegBuffer = new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,23 +285,28 @@ private static Serval.Translation.V1.Phrase Map(Translation.Phrase source)

private static Models.Corpus Map(Serval.Translation.V1.Corpus source)
{
var pretranslateChapters = source.PretranslateChapters.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value.Chapters.ToHashSet()
);
FilterChoice pretranslateFilter = GetFilterChoice(source.PretranslateAll, pretranslateChapters);

var trainOnChapters = source.TrainOnChapters.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value.Chapters.ToHashSet()
);
FilterChoice trainingFilter = GetFilterChoice(source.TrainOnAll, trainOnChapters);

return new Models.Corpus
{
Id = source.Id,
SourceLanguage = source.SourceLanguage,
TargetLanguage = source.TargetLanguage,
TrainOnAll = source.TrainOnAll,
PretranslateAll = source.PretranslateAll,
TrainOnChapters = source.TrainOnChapters.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value.Chapters.ToHashSet()
),
PretranslateChapters = source.PretranslateChapters.ToDictionary(
kvp => kvp.Key,
kvp => kvp.Value.Chapters.ToHashSet()
),
TrainOnTextIds = source.TrainOnTextIds.ToHashSet(),
PretranslateTextIds = source.PretranslateTextIds.ToHashSet(),
TrainOnChapters = trainingFilter == FilterChoice.Chapters ? trainOnChapters : null,
PretranslateChapters = pretranslateFilter == FilterChoice.Chapters ? pretranslateChapters : null,
TrainOnTextIds = trainingFilter == FilterChoice.TextIds ? source.TrainOnTextIds.ToHashSet() : null,
PretranslateTextIds =
pretranslateFilter == FilterChoice.TextIds ? source.PretranslateTextIds.ToHashSet() : null,
SourceFiles = source.SourceFiles.Select(Map).ToList(),
TargetFiles = source.TargetFiles.Select(Map).ToList()
};
Expand All @@ -316,4 +321,24 @@ private static Models.CorpusFile Map(Serval.Translation.V1.CorpusFile source)
TextId = source.TextId
};
}

private enum FilterChoice
{
Chapters,
TextIds,
None
}

private static FilterChoice GetFilterChoice(bool all, IReadOnlyDictionary<string, HashSet<int>> chapters)
{
if (all)
return FilterChoice.None;

// Only either textIds or Scripture Range will be used at a time
// TextIds may be an empty array, so prefer that if both are empty (which applies to both scripture and text)
if (chapters.Count == 0)
return FilterChoice.TextIds;
else
return FilterChoice.Chapters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public async Task RunAsync_FilterOutEverything()
public async Task RunAsync_TrainOnAll()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnAll = true };
Corpus corpus1 = env.DefaultTextFileCorpus with { TrainOnTextIds = null };

await env.RunBuildJobAsync(corpus1);

Expand Down Expand Up @@ -61,7 +61,7 @@ public async Task RunAsync_TrainOnTextIds()
public async Task RunAsync_TrainAndPretranslateAll()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateAll = true, TrainOnAll = true };
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = null, TrainOnTextIds = null };

await env.RunBuildJobAsync(corpus1);

Expand All @@ -72,7 +72,7 @@ public async Task RunAsync_TrainAndPretranslateAll()
public async Task RunAsync_PretranslateAll()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateAll = true };
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = null };

await env.RunBuildJobAsync(corpus1);

Expand All @@ -83,7 +83,7 @@ public async Task RunAsync_PretranslateAll()
public async Task RunAsync_PretranslateTextIds()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = ["textId1"], TrainOnAll = true };
Corpus corpus1 = env.DefaultTextFileCorpus with { PretranslateTextIds = ["textId1"], TrainOnTextIds = null };

await env.RunBuildJobAsync(corpus1);

Expand Down Expand Up @@ -177,7 +177,11 @@ public async Task RunAsync_TrainOnChapters()
public async Task RunAsync_MixedSource_Paratext()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultMixedSourceParatextCorpus with { TrainOnAll = true, PretranslateAll = true };
Corpus corpus1 = env.DefaultMixedSourceParatextCorpus with
{
TrainOnTextIds = null,
PretranslateTextIds = null
};

await env.RunBuildJobAsync(corpus1, useKeyTerms: false);

Expand All @@ -196,7 +200,13 @@ public async Task RunAsync_MixedSource_Paratext()
public async Task RunAsync_MixedSource_Text()
{
using TestEnvironment env = new();
Corpus corpus1 = env.DefaultMixedSourceTextFileCorpus with { TrainOnAll = true, PretranslateAll = true };
Corpus corpus1 = env.DefaultMixedSourceTextFileCorpus with
{
TrainOnTextIds = null,
PretranslateTextIds = null,
TrainOnChapters = null,
PretranslateChapters = null
};

await env.RunBuildJobAsync(corpus1);

Expand Down Expand Up @@ -268,8 +278,6 @@ public TestEnvironment()
Id = "corpusId1",
SourceLanguage = "es",
TargetLanguage = "en",
PretranslateAll = false,
TrainOnAll = false,
PretranslateTextIds = [],
TrainOnTextIds = [],
SourceFiles = [TextFile("source1")],
Expand All @@ -281,8 +289,6 @@ public TestEnvironment()
Id = "corpusId1",
SourceLanguage = "es",
TargetLanguage = "en",
PretranslateAll = false,
TrainOnAll = false,
PretranslateTextIds = [],
TrainOnTextIds = [],
SourceFiles = [TextFile("source1"), TextFile("source2")],
Expand All @@ -294,8 +300,6 @@ public TestEnvironment()
Id = "corpusId1",
SourceLanguage = "es",
TargetLanguage = "en",
PretranslateAll = false,
TrainOnAll = false,
PretranslateTextIds = [],
TrainOnTextIds = [],
SourceFiles = [ParatextFile("pt-source1")],
Expand All @@ -307,8 +311,6 @@ public TestEnvironment()
Id = "corpusId1",
SourceLanguage = "es",
TargetLanguage = "en",
PretranslateAll = false,
TrainOnAll = false,
PretranslateTextIds = [],
TrainOnTextIds = [],
SourceFiles = [ParatextFile("pt-source1"), ParatextFile("pt-source2")],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ await env.Service.StartBuildAsync(
TargetLanguage = "en",
SourceFiles = [],
TargetFiles = [],
TrainOnAll = true,
PretranslateAll = true,
TrainOnTextIds = [],
PretranslateTextIds = []
TrainOnTextIds = null,
PretranslateTextIds = null
}
}
);
Expand Down
Loading