Skip to content

Commit

Permalink
Refactor alignment data structure; revert to pretranslate/word-align …
Browse files Browse the repository at this point in the history
…where appropriate
  • Loading branch information
Enkidu93 committed Jan 9, 2025
1 parent bbc3248 commit 97b599f
Showing 20 changed files with 75 additions and 74 deletions.
13 changes: 6 additions & 7 deletions src/Echo/src/EchoEngine/TranslationEngineServiceV1.cs
Original file line number Diff line number Diff line change
@@ -78,9 +78,8 @@ await client.BuildStartedAsync(
try
{
using (
AsyncClientStreamingCall<InsertInferencesRequest, Empty> call = client.InsertInferences(
cancellationToken: cancellationToken
)
AsyncClientStreamingCall<InsertPretranslationsRequest, Empty> call =
client.InsertPretranslations(cancellationToken: cancellationToken)
)
{
foreach (ParallelCorpus corpus in request.Corpora)
@@ -133,7 +132,7 @@ await client.BuildStartedAsync(
if (sourceLine.Length > 0 && targetLine.Length == 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
@@ -166,7 +165,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0 && targetLine.Length == 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
@@ -191,7 +190,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
@@ -212,7 +211,7 @@ await call.RequestStream.WriteAsync(
if (sourceLine.Length > 0)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
6 changes: 3 additions & 3 deletions src/Echo/src/EchoEngine/WordAlignmentEngineServiceV1.cs
Original file line number Diff line number Diff line change
@@ -69,7 +69,7 @@ await client.BuildStartedAsync(
try
{
using (
AsyncClientStreamingCall<InsertInferencesRequest, Empty> call = client.InsertInferences(
AsyncClientStreamingCall<InsertWordAlignmentsRequest, Empty> call = client.InsertWordAlignments(
cancellationToken: cancellationToken
)
)
@@ -128,7 +128,7 @@ await client.BuildStartedAsync(
targetLine.Split().Length
);
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertWordAlignmentsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
@@ -168,7 +168,7 @@ await call.RequestStream.WriteAsync(
targetLine.Split().Length
);
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertWordAlignmentsRequest
{
EngineId = request.EngineId,
CorpusId = corpus.Id,
Original file line number Diff line number Diff line change
@@ -75,6 +75,15 @@ public static IMachineBuilder AddThotSmtModel(this IMachineBuilder builder, ICon
return builder;
}

/// <summary>
/// Registers the word alignment model options and factory on the builder's service collection.
/// </summary>
/// <remarks>
/// <see cref="WordAlignmentModelOptions"/> is configured from the configuration section named by
/// <see cref="WordAlignmentModelOptions.Key"/>, and <see cref="WordAlignmentModelFactory"/> is added as the
/// singleton implementation of <see cref="IWordAlignmentModelFactory"/>.
/// </remarks>
/// <param name="builder">The machine builder whose services are being configured.</param>
/// <returns>The same <paramref name="builder"/> instance, so that calls can be chained.</returns>
public static IMachineBuilder AddWordAlignmentModel(this IMachineBuilder builder)
{
    // Look the section up once, then hand it to the options registration.
    var modelOptionsSection = builder.Configuration.GetSection(WordAlignmentModelOptions.Key);
    builder.Services.Configure<WordAlignmentModelOptions>(modelOptionsSection);

    builder.Services.AddSingleton<IWordAlignmentModelFactory, WordAlignmentModelFactory>();

    return builder;
}

public static IMachineBuilder AddTransferEngine(this IMachineBuilder builder)
{
builder.Services.AddSingleton<ITransferEngineFactory, TransferEngineFactory>();
@@ -485,7 +494,7 @@ public static IMachineBuilder AddThot(this IMachineBuilder builder)
{
try
{
builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser();
builder.AddThotSmtModel().AddTransferEngine().AddUnigramTruecaser().AddWordAlignmentModel();
}
catch (ArgumentException)
{
@@ -516,9 +525,9 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder)
var smtTransferEngineOptions = new SmtTransferEngineOptions();
builder.Configuration.GetSection(SmtTransferEngineOptions.Key).Bind(smtTransferEngineOptions);
string? smtDriveLetter = Path.GetPathRoot(smtTransferEngineOptions.EnginesDir)?[..1];
var statisticsEngineOptions = new WordAlignmentEngineOptions();
builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticsEngineOptions);
string? statisticsDriveLetter = Path.GetPathRoot(statisticsEngineOptions.EnginesDir)?[..1];
var statisticalEngineOptions = new WordAlignmentEngineOptions();
builder.Configuration.GetSection(WordAlignmentEngineOptions.Key).Bind(statisticalEngineOptions);
string? statisticsDriveLetter = Path.GetPathRoot(statisticalEngineOptions.EnginesDir)?[..1];
if (smtDriveLetter is null || statisticsDriveLetter is null)
throw new InvalidOperationException("SMT Engine and Statistical directory is required");
if (smtDriveLetter != statisticsDriveLetter)
Original file line number Diff line number Diff line change
@@ -7,6 +7,6 @@ public record WordAlignment
public required IReadOnlyList<string> Refs { get; init; }
public required IReadOnlyList<string> SourceTokens { get; set; }
public required IReadOnlyList<string> TargetTokens { get; set; }
public required IReadOnlyList<double> Confidences { get; set; }
public required IReadOnlyList<double> Confidences { get; set; } // TODO: It seems to me that it'd be more natural to have the confidence as part of the word pair object, but I understand that this is currently not the case with the translation result; would it be a breaking change to alter it there too?
public required IReadOnlyList<AlignedWordPair> Alignment { get; set; }
}
Original file line number Diff line number Diff line change
@@ -59,12 +59,12 @@ await _client.BuildRestartingAsync(
)
.OfType<Pretranslation>();

using (var call = _client.InsertInferences(cancellationToken: cancellationToken))
using (var call = _client.InsertPretranslations(cancellationToken: cancellationToken))
{
await foreach (Pretranslation pretranslation in pretranslations)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = content!,
CorpusId = pretranslation.CorpusId,
Original file line number Diff line number Diff line change
@@ -122,21 +122,18 @@ private static WordAlignment.V1.WordAlignmentResult Map(SIL.Machine.Translation.
{
SourceTokens = { source.SourceTokens },
TargetTokens = { source.TargetTokens },
Alignment = { Map(source.Alignment) },
Confidences = { source.Confidences }
Confidences = { source.Confidences },
Alignment = { source.AlignedWordPairs.Select(Map) }

Check failure on line 126 in src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs

GitHub Actions / Build

'WordAlignmentResult' does not contain a definition for 'AlignedWordPairs' and no accessible extension method 'AlignedWordPairs' accepting a first argument of type 'WordAlignmentResult' could be found (are you missing a using directive or an assembly reference?)

Check failure on line 126 in src/Machine/src/Serval.Machine.Shared/Services/ServalWordAlignmentEngineServiceV1.cs

GitHub Actions / Build

'WordAlignmentResult' does not contain a definition for 'AlignedWordPairs' and no accessible extension method 'AlignedWordPairs' accepting a first argument of type 'WordAlignmentResult' could be found (are you missing a using directive or an assembly reference?)
};
}

private static IEnumerable<WordAlignment.V1.AlignedWordPair> Map(WordAlignmentMatrix source)
private static WordAlignment.V1.AlignedWordPair Map(SIL.Machine.Corpora.AlignedWordPair source)
{
for (int i = 0; i < source.RowCount; i++)
return new WordAlignment.V1.AlignedWordPair
{
for (int j = 0; j < source.ColumnCount; j++)
{
if (source[i, j])
yield return new WordAlignment.V1.AlignedWordPair { SourceIndex = i, TargetIndex = j };
}
}
SourceIndex = source.SourceIndex,
TargetIndex = source.TargetIndex
};
}

private static SIL.ServiceToolkit.Models.ParallelCorpus Map(WordAlignment.V1.ParallelCorpus source)
Original file line number Diff line number Diff line change
@@ -62,12 +62,12 @@ await _client.BuildRestartingAsync(
)
.OfType<Models.WordAlignment>();

using (var call = _client.InsertInferences(cancellationToken: cancellationToken))
using (var call = _client.InsertWordAlignments(cancellationToken: cancellationToken))
{
await foreach (Models.WordAlignment wordAlignment in wordAlignments)
{
await call.RequestStream.WriteAsync(
new InsertInferencesRequest
new InsertWordAlignmentsRequest
{
EngineId = content!,
CorpusId = wordAlignment.CorpusId,
Original file line number Diff line number Diff line change
@@ -80,21 +80,15 @@ public async Task<WordAlignmentResult> GetBestWordAlignmentAsync(
return new WordAlignmentResult(
sourceTokens: sourceTokens,
targetTokens: targetTokens,
alignment: new WordAlignmentMatrix(
sourceTokens.Count,
targetTokens.Count,
wordPairs.Select(wp => (wp.SourceIndex, wp.TargetIndex))
),
confidences: wordPairs.Select(wp => wp.AlignmentScore * wp.TranslationScore).ToList()
alignedWordPairs: wordPairs,

Check failure on line 83 in src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs

GitHub Actions / Build

The best overload for 'WordAlignmentResult' does not have a parameter named 'alignedWordPairs'

Check failure on line 83 in src/Machine/src/Serval.Machine.Shared/Services/StatisticalEngineService.cs

GitHub Actions / Build

The best overload for 'WordAlignmentResult' does not have a parameter named 'alignedWordPairs'
confidences: wordPairs.Select(wp => wp.AlignmentScore).ToList()
);
},
cancellationToken: cancellationToken
);

state.Touch();
return result;

throw new NotImplementedException();
}

public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default)
Original file line number Diff line number Diff line change
@@ -48,13 +48,12 @@ await ParallelCorpusPreprocessingService.PreprocessAsync(
corpora,
async row =>
{
if (row.SourceSegment.Length > 0 || row.TargetSegment.Length > 0)
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
{
await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n");
await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n");
}
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
}
},
async (row, corpus) =>
{
Original file line number Diff line number Diff line change
@@ -51,10 +51,10 @@ await env.Handler.HandleMessageAsync(
);
}

_ = env.Client.Received(1).InsertInferences();
_ = env.Client.Received(1).InsertPretranslations();
_ = env.PretranslationWriter.Received(1)
.WriteAsync(
new InsertInferencesRequest
new InsertPretranslationsRequest
{
EngineId = "engine1",
CorpusId = "corpus1",
@@ -78,9 +78,9 @@ public TestEnvironment()
Client
.IncrementTrainEngineCorpusSizeAsync(Arg.Any<IncrementTrainEngineCorpusSizeRequest>())
.Returns(CreateEmptyUnaryCall());
PretranslationWriter = Substitute.For<IClientStreamWriter<InsertInferencesRequest>>();
PretranslationWriter = Substitute.For<IClientStreamWriter<InsertPretranslationsRequest>>();
Client
.InsertInferences(cancellationToken: Arg.Any<CancellationToken>())
.InsertPretranslations(cancellationToken: Arg.Any<CancellationToken>())
.Returns(
TestCalls.AsyncClientStreamingCall(
PretranslationWriter,
@@ -97,7 +97,7 @@ public TestEnvironment()

public TranslationPlatformApi.TranslationPlatformApiClient Client { get; }
public ServalTranslationPlatformOutboxMessageHandler Handler { get; }
public IClientStreamWriter<InsertInferencesRequest> PretranslationWriter { get; }
public IClientStreamWriter<InsertPretranslationsRequest> PretranslationWriter { get; }

private static AsyncUnaryCall<Empty> CreateEmptyUnaryCall()
{
Original file line number Diff line number Diff line change
@@ -135,7 +135,8 @@ public async Task GetBestWordAlignment()
);
Assert.That(string.Join(' ', result.TargetTokens), Is.EqualTo("this is a test ."));
Assert.That(result.Confidences, Has.Count.EqualTo(5));
Assert.That(result.Alignment[0, 0], Is.True);
Assert.That(result.AlignedWordPairs.First().SourceIndex, Is.EqualTo(0));
Assert.That(result.AlignedWordPairs.First().TargetIndex, Is.EqualTo(0));
}

private class TestEnvironment : DisposableBase
8 changes: 4 additions & 4 deletions src/Serval/src/Serval.Client/Client.g.cs
Original file line number Diff line number Diff line change
@@ -7137,8 +7137,8 @@ public partial interface IWordAlignmentEnginesClient
/// <br/> * An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed.
/// <br/>* **SourceTokens**: the tokenized source segment
/// <br/>* **TargetTokens**: the tokenized target segment
/// <br/>* **Confidences**: the confidence of the alignment ona scale from 0 to 1
/// <br/>* **Alignment**: the word alignment, 0 indexed for source and target positions
/// <br/>* **Confidences**: the confidence of the alignment on a scale from 0 to 1
/// <br/>* **Alignment**: a list of aligned word pairs
/// <br/>
/// <br/>Word alignments can be filtered by text id if provided.
/// <br/>Only word alignments for the most recent successful build of the engine are returned.
@@ -8406,8 +8406,8 @@ public string BaseUrl
/// <br/> * An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed.
/// <br/>* **SourceTokens**: the tokenized source segment
/// <br/>* **TargetTokens**: the tokenized target segment
/// <br/>* **Confidences**: the confidence of the alignment ona scale from 0 to 1
/// <br/>* **Alignment**: the word alignment, 0 indexed for source and target positions
/// <br/>* **Confidences**: the confidence of the alignment on a scale from 0 to 1
/// <br/>* **Alignment**: a list of aligned word pairs
/// <br/>
/// <br/>Word alignments can be filtered by text id if provided.
/// <br/>Only word alignments for the most recent successful build of the engine are returned.
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ service TranslationPlatformApi {
rpc BuildRestarting(BuildRestartingRequest) returns (google.protobuf.Empty);

rpc IncrementTrainEngineCorpusSize(IncrementTrainEngineCorpusSizeRequest) returns (google.protobuf.Empty);
rpc InsertInferences(stream InsertInferencesRequest) returns (google.protobuf.Empty);
rpc InsertPretranslations(stream InsertPretranslationsRequest) returns (google.protobuf.Empty);
}

message UpdateBuildStatusRequest {
@@ -52,7 +52,7 @@ message IncrementTrainEngineCorpusSizeRequest {
int32 count = 2;
}

message InsertInferencesRequest {
message InsertPretranslationsRequest {
string engine_id = 1;
string corpus_id = 2;
string text_id = 3;
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ service WordAlignmentPlatformApi {
rpc BuildRestarting(BuildRestartingRequest) returns (google.protobuf.Empty);

rpc IncrementTrainEngineCorpusSize(IncrementTrainEngineCorpusSizeRequest) returns (google.protobuf.Empty);
rpc InsertInferences(stream InsertInferencesRequest) returns (google.protobuf.Empty);
rpc InsertWordAlignments(stream InsertWordAlignmentsRequest) returns (google.protobuf.Empty);
}

message UpdateBuildStatusRequest {
@@ -54,7 +54,7 @@ message IncrementTrainEngineCorpusSizeRequest {
int32 count = 2;
}

message InsertInferencesRequest {
message InsertWordAlignmentsRequest {
string engine_id = 1;
string corpus_id = 2;
string text_id = 3;
Original file line number Diff line number Diff line change
@@ -278,16 +278,16 @@ await _engines.UpdateAsync(
return Empty;
}

public override async Task<Empty> InsertInferences(
IAsyncStreamReader<InsertInferencesRequest> requestStream,
public override async Task<Empty> InsertPretranslations(
IAsyncStreamReader<InsertPretranslationsRequest> requestStream,
ServerCallContext context
)
{
string engineId = "";
int nextModelRevision = 0;

var batch = new List<Pretranslation>();
await foreach (InsertInferencesRequest request in requestStream.ReadAllAsync(context.CancellationToken))
await foreach (InsertPretranslationsRequest request in requestStream.ReadAllAsync(context.CancellationToken))
{
if (request.EngineId != engineId)
{
Original file line number Diff line number Diff line change
@@ -387,8 +387,8 @@ CancellationToken cancellationToken
/// * An auto-generated reference of `[TextId]:[lineNumber]`, 1 indexed.
/// * **SourceTokens**: the tokenized source segment
/// * **TargetTokens**: the tokenized target segment
/// * **Confidences**: the confidence of the alignment ona scale from 0 to 1
/// * **Alignment**: the word alignment, 0 indexed for source and target positions
/// * **Confidences**: the confidence of the alignment on a scale from 0 to 1
/// * **Alignment**: a list of aligned word pairs
///
/// Word alignments can be filtered by text id if provided.
/// Only word alignments for the most recent successful build of the engine are returned.
Original file line number Diff line number Diff line change
@@ -278,16 +278,16 @@ await _engines.UpdateAsync(
return Empty;
}

public override async Task<Empty> InsertInferences(
IAsyncStreamReader<InsertInferencesRequest> requestStream,
public override async Task<Empty> InsertWordAlignments(
IAsyncStreamReader<InsertWordAlignmentsRequest> requestStream,
ServerCallContext context
)
{
string engineId = "";
int nextModelRevision = 0;

var batch = new List<Models.WordAlignment>();
await foreach (InsertInferencesRequest request in requestStream.ReadAllAsync(context.CancellationToken))
await foreach (InsertWordAlignmentsRequest request in requestStream.ReadAllAsync(context.CancellationToken))
{
if (request.EngineId != engineId)
{
Loading

0 comments on commit 97b599f

Please sign in to comment.