diff --git a/src/SIL.Machine.AspNetCore/Models/Corpus.cs b/src/SIL.Machine.AspNetCore/Models/Corpus.cs index 84dda461f..c33bc52ce 100644 --- a/src/SIL.Machine.AspNetCore/Models/Corpus.cs +++ b/src/SIL.Machine.AspNetCore/Models/Corpus.cs @@ -5,7 +5,9 @@ public class Corpus public string Id { get; set; } = default!; public string SourceLanguage { get; set; } = default!; public string TargetLanguage { get; set; } = default!; + public bool TrainOnAll { get; set; } public bool PretranslateAll { get; set; } + public HashSet TrainOnTextIds { get; set; } = default!; public HashSet PretranslateTextIds { get; set; } = default!; public List SourceFiles { get; set; } = default!; public List TargetFiles { get; set; } = default!; diff --git a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs index e1202dec8..3fbc41ee0 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtPreprocessBuildJob.cs @@ -76,8 +76,11 @@ async IAsyncEnumerable ProcessRowsAsync() foreach (ParallelTextRow row in parallelCorpus) { - await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); - await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); + if (corpus.TrainOnAll || corpus.TrainOnTextIds.Contains(row.TextId)) + { + await sourceTrainWriter.WriteAsync($"{row.SourceText}\n"); + await targetTrainWriter.WriteAsync($"{row.TargetText}\n"); + } if ( (corpus.PretranslateAll || corpus.PretranslateTextIds.Contains(row.TextId)) && row.SourceSegment.Count > 0 diff --git a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs index fc2173053..35329ffc0 100644 --- a/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs +++ b/src/SIL.Machine.AspNetCore/Services/S3WriteStream.cs @@ -55,37 +55,40 @@ public override void Flush() { } public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - try + if (count > 0) { - using MemoryStream ms = new(buffer, offset, count); - int partNumber = _uploadResponses.Count + 1; - UploadPartRequest request = - new() - { - BucketName = _bucketName, - Key = _key, - UploadId = _uploadId, - PartNumber = partNumber, - InputStream = ms, - PartSize = MaxPartSize - }; - request.StreamTransferProgress += new EventHandler( - (_, e) => - { - _logger.LogDebug($"Transferred {e.TransferredBytes}/{e.TotalBytes}"); - } - ); - UploadPartResponse response = await _client.UploadPartAsync(request); - if (response.HttpStatusCode != HttpStatusCode.OK) - throw new HttpRequestException( - $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + try + { + using MemoryStream ms = new(buffer, offset, count); + int partNumber = _uploadResponses.Count + 1; + UploadPartRequest request = + new() + { + BucketName = _bucketName, + Key = _key, + UploadId = _uploadId, + PartNumber = partNumber, + InputStream = ms, + PartSize = MaxPartSize + }; + request.StreamTransferProgress += new EventHandler( + (_, e) => + { + _logger.LogDebug($"Transferred {e.TransferredBytes}/{e.TotalBytes}"); + } ); - _uploadResponses.Add(response); - } - catch (Exception e) - { - await AbortAsync(e); - throw; + UploadPartResponse response = await _client.UploadPartAsync(request); + if (response.HttpStatusCode != HttpStatusCode.OK) + throw new HttpRequestException( + $"Tried to upload part {partNumber} of upload {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + _uploadResponses.Add(response); + } + catch (Exception e) + { + await AbortAsync(e); + throw; + } } } @@ -93,30 +96,49 @@ protected override void Dispose(bool disposing) { if (disposing) { - try + if (_uploadResponses.Count == 0) { - CompleteMultipartUploadRequest request = + AbortAsync().WaitAndUnwrapException(); + PutObjectRequest request = new() { BucketName = _bucketName, Key = _key, - UploadId = _uploadId + ContentBody = "" }; - request.AddPartETags(_uploadResponses); - CompleteMultipartUploadResponse response = _client - .CompleteMultipartUploadAsync(request) - .WaitAndUnwrapException(); - Dispose(disposing: false); - GC.SuppressFinalize(this); + PutObjectResponse response = _client.PutObjectAsync(request).WaitAndUnwrapException(); if (response.HttpStatusCode != HttpStatusCode.OK) throw new HttpRequestException( - $"Tried to complete {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + $"Tried to upload empty file to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" ); } - catch (Exception e) + else { - AbortAsync(e).WaitAndUnwrapException(); - throw; + try + { + CompleteMultipartUploadRequest request = + new() + { + BucketName = _bucketName, + Key = _key, + UploadId = _uploadId + }; + request.AddPartETags(_uploadResponses); + CompleteMultipartUploadResponse response = _client + .CompleteMultipartUploadAsync(request) + .WaitAndUnwrapException(); + Dispose(disposing: false); + GC.SuppressFinalize(this); + if (response.HttpStatusCode != HttpStatusCode.OK) + throw new HttpRequestException( + $"Tried to complete {_uploadId} to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + } + catch (Exception e) + { + AbortAsync(e).WaitAndUnwrapException(); + throw; + } } } base.Dispose(disposing); @@ -124,6 +146,23 @@ protected override void Dispose(bool disposing) public async override ValueTask DisposeAsync() { + if (_uploadResponses.Count == 0) + { + await AbortAsync(); + PutObjectRequest request = + new() + { + BucketName = _bucketName, + Key = _key, + ContentBody = "" + }; + PutObjectResponse response = await _client.PutObjectAsync(request); + if (response.HttpStatusCode != HttpStatusCode.OK) + throw new HttpRequestException( + $"Tried to upload empty file to {_bucketName}/{_key} but received response code {response.HttpStatusCode}" + ); + return; + } try { CompleteMultipartUploadRequest request = @@ -148,9 +187,10 @@ public async override ValueTask DisposeAsync() } } - private async Task AbortAsync(Exception e) + private async Task AbortAsync(Exception? e = null) { - _logger.LogError(e, $"Aborted upload {_uploadId} to {_bucketName}/{_key}"); + if (e is not null) + _logger.LogError(e, $"Aborted upload {_uploadId} to {_bucketName}/{_key}"); AbortMultipartUploadRequest abortMPURequest = new() { diff --git a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs index f57038f01..a0b010c61 100644 --- a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs +++ b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs @@ -236,7 +236,9 @@ private static Models.Corpus Map(Serval.Translation.V1.Corpus source) Id = source.Id, SourceLanguage = source.SourceLanguage, TargetLanguage = source.TargetLanguage, + TrainOnAll = source.TrainOnAll, PretranslateAll = source.PretranslateAll, + TrainOnTextIds = source.TrainOnTextIds.ToHashSet(), PretranslateTextIds = source.PretranslateTextIds.ToHashSet(), SourceFiles = source.SourceFiles.Select(Map).ToList(), TargetFiles = source.TargetFiles.Select(Map).ToList()