diff --git a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs index 878d7a5a1..51d7fe1e5 100644 --- a/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs @@ -9,6 +9,7 @@ IInteractiveTranslationModel Create( ITruecaser truecaser ); ITrainer CreateTrainer(string engineId, IRangeTokenizer tokenizer, IParallelTextCorpus corpus); + Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken); void InitNew(string engineId); void Cleanup(string engineId); diff --git a/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs index 3daf15159..fc19f0322 100644 --- a/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs @@ -4,15 +4,15 @@ public abstract class PreprocessBuildJob : HangfireBuildJob engines, diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs index 29d1dee03..bc457a34a 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferPostprocessBuildJob.cs @@ -30,7 +30,7 @@ CancellationToken cancellationToken await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) { await _smtModelFactory.DownloadBuiltEngineAsync(engineId, cancellationToken); - int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, @lock, cancellationToken); + int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, cancellationToken); await PlatformService.BuildCompletedAsync( buildId, trainSize: data.Item1 + segmentPairsSize, @@ -43,50 +43,38 @@ await PlatformService.BuildCompletedAsync( Logger.LogInformation("Build completed ({0}).", buildId); } - private async Task TrainOnNewSegmentPairs( - string engineId, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) + private async Task TrainOnNewSegmentPairs(string engineId, CancellationToken cancellationToken) { TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); if (engine is null) throw new OperationCanceledException(); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - cancellationToken.ThrowIfCancellationRequested(); - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine.Id, - CancellationToken.None - ); - if (segmentPairs.Count == 0) - return segmentPairs.Count; + cancellationToken.ThrowIfCancellationRequested(); + IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( + p => p.TranslationEngineRef == engine.Id, + CancellationToken.None + ); + if (segmentPairs.Count == 0) + return segmentPairs.Count; - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); + var tokenizer = new LatinWordTokenizer(); + var detokenizer = new LatinWordDetokenizer(); + ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId); - using ( - IInteractiveTranslationModel smtModel = _smtModelFactory.Create( - engineId, - tokenizer, - detokenizer, - truecaser - ) - ) + using ( + IInteractiveTranslationModel smtModel = _smtModelFactory.Create(engineId, tokenizer, detokenizer, truecaser) + ) + { + foreach (TrainSegmentPair segmentPair in segmentPairs) { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - await smtModel.SaveAsync(CancellationToken.None); + await smtModel.TrainSegmentAsync( + segmentPair.Source, + segmentPair.Target, + cancellationToken: CancellationToken.None + ); } - return segmentPairs.Count; + await smtModel.SaveAsync(CancellationToken.None); } + return segmentPairs.Count; } } diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs index da3a44375..7796a6260 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs @@ -3,7 +3,6 @@ public class SmtTransferTrainBuildJob( IPlatformService platformService, IRepository engines, - IOptionsMonitor engineOptions, IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, ILogger logger, @@ -15,7 +14,6 @@ ISmtModelFactory smtModelFactory private readonly ISharedFileService _sharedFileService = sharedFileService; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; - private readonly IOptionsMonitor _engineOptions = engineOptions; protected override async Task DoWorkAsync( string engineId, @@ -49,17 +47,12 @@ CancellationToken cancellationToken await smtModelTrainer.TrainAsync(progress, cancellationToken); await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); - string modelDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); - Directory.CreateDirectory(modelDir); - cancellationToken.ThrowIfCancellationRequested(); + await smtModelTrainer.SaveAsync(CancellationToken.None); await truecaseTrainer.SaveAsync(CancellationToken.None); - // save model to S3 bucket - Stream modelDst = await _sharedFileService.OpenWriteAsync($"models/{engineId}.zip"); - ZipFile.CreateFromDirectory(modelDir, modelDst); - modelDst.Close(); + await _smtModelFactory.UploadBuiltEngineAsync(engineId, cancellationToken); cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs index a0d88f646..29a58df3d 100644 --- a/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs @@ -57,6 +57,14 @@ public async Task DownloadBuiltEngineAsync(string engineId, CancellationToken ca await _sharedFileService.DeleteAsync(sharedFilePath); } + public async Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken) + { + string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); + string sharedFilePath = $"models/{engineId}.zip"; + using Stream sharedStream = await _sharedFileService.OpenWriteAsync(sharedFilePath, cancellationToken); + ZipFile.CreateFromDirectory(engineDir, sharedStream); + } + public void InitNew(string engineId) { string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index 8a8e5e85b..89d15a586 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -56,6 +56,14 @@ await env.Service.StartBuildAsync( ] ); await env.WaitForBuildToFinishAsync(); + await env + .SmtBatchTrainer.Received() + .TrainAsync(Arg.Any>(), Arg.Any()); + await env + .TruecaserTrainer.Received() + .TrainAsync(Arg.Any>(), Arg.Any()); + await env.SmtBatchTrainer.Received().SaveAsync(Arg.Any()); + await env.TruecaserTrainer.Received().SaveAsync(Arg.Any()); engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); Assert.That(engine.BuildRevision, Is.EqualTo(2)); @@ -119,11 +127,10 @@ await env.SmtBatchTrainer.TrainAsync( }) ); await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); - await env.WaitForBuildToStartAsync(); + await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active)); - await Task.Delay(200); env.StopServer(); await env.WaitForBuildToRestartAsync(); engine = env.Engines.Get(EngineId1); @@ -270,17 +277,7 @@ public TestEnvironment() ); Truecaser = Substitute.For(); TruecaserTrainer = Substitute.For(); - EngineOptions = Substitute.For>(); - DirectoryInfo tempDir = Directory.CreateTempSubdirectory(); - EngineOptions.CurrentValue.Returns(new SmtTransferEngineOptions() { EnginesDir = tempDir.FullName }); - Task SaveTrueCaserAsync() - { - using (File.Create(Path.Combine(tempDir.FullName, EngineId1, "unigram-casing-model.txt"))) { } - return Task.CompletedTask; - } - - TruecaserTrainer.When(x => x.SaveAsync()).Do(_ => SaveTrueCaserAsync()); SmtModelFactory = CreateSmtModelFactory(); TransferEngineFactory = CreateTransferEngineFactory(); _truecaserFactory = CreateTruecaserFactory(); @@ -356,7 +353,6 @@ [new SmtTransferClearMLBuildJobFactory(SharedFileService, Engines)], public IClearMLQueueService ClearMLMonitorService { get; } public ISharedFileService SharedFileService { get; } - public IOptionsMonitor EngineOptions { get; } public IBuildJobService BuildJobService { get; } public Func TrainJobFunc { get; set; } @@ -370,8 +366,8 @@ public async Task CommitAsync(TimeSpan inactiveTimeout) public void StopServer() { - StateService.Dispose(); _jobServer.Dispose(); + StateService.Dispose(); } public void StartServer() @@ -580,6 +576,13 @@ public Task WaitForBuildToStartAsync() return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Active); } + public Task WaitForTrainingToStartAsync() + { + return WaitForBuildState(e => + e.CurrentBuild!.JobState is BuildJobState.Active && e.CurrentBuild!.Stage is BuildStage.Train + ); + } + public Task WaitForBuildToRestartAsync() { return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Pending); @@ -614,44 +617,22 @@ private class EnvActivator(TestEnvironment env) : JobActivator { private readonly TestEnvironment _env = env; - public class SmtTransferPreprocessBuildJobTest : SmtTransferPreprocessBuildJob - { - public SmtTransferPreprocessBuildJobTest( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - ILogger logger, - IBuildJobService buildJobService, - ISharedFileService sharedFileService, - ICorpusService corpusService - ) - : base( - platformService, - engines, - lockFactory, - logger, - buildJobService, - sharedFileService, - corpusService - ) - { - TrainJobRunnerType = BuildJobRunnerType.Hangfire; - } - } - public override object ActivateJob(Type jobType) { if (jobType == typeof(SmtTransferPreprocessBuildJob)) { - return new SmtTransferPreprocessBuildJobTest( + return new SmtTransferPreprocessBuildJob( _env.PlatformService, _env.Engines, _env._lockFactory, - Substitute.For>(), + Substitute.For>(), _env.BuildJobService, _env.SharedFileService, Substitute.For() - ); + ) + { + TrainJobRunnerType = BuildJobRunnerType.Hangfire + }; } if (jobType == typeof(SmtTransferPostprocessBuildJob)) { @@ -672,7 +653,6 @@ public override object ActivateJob(Type jobType) return new SmtTransferTrainBuildJob( _env.PlatformService, _env.Engines, - _env.EngineOptions, _env._lockFactory, _env.BuildJobService, Substitute.For>(),