Skip to content

Commit

Permalink
reviewer comments
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed May 22, 2024
1 parent ea8f6ae commit b6ac585
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 92 deletions.
1 change: 1 addition & 0 deletions src/SIL.Machine.AspNetCore/Services/ISmtModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ IInteractiveTranslationModel Create(
ITruecaser truecaser
);
ITrainer CreateTrainer(string engineId, IRangeTokenizer<string, int, string> tokenizer, IParallelTextCorpus corpus);
Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken);
Task DownloadBuiltEngineAsync(string engineId, CancellationToken cancellationToken);
void InitNew(string engineId);
void Cleanup(string engineId);
Expand Down
8 changes: 4 additions & 4 deletions src/SIL.Machine.AspNetCore/Services/PreprocessBuildJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ public abstract class PreprocessBuildJob : HangfireBuildJob<IReadOnlyList<Models
{
private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true };

public BuildJobRunnerType TrainJobRunnerType { get; protected init; } = BuildJobRunnerType.ClearML;
public BuildJobRunnerType TrainJobRunnerType { get; init; } = BuildJobRunnerType.ClearML;
protected TranslationEngineType EngineType { get; init; }
protected bool PretranslationEnabled { get; init; }

private readonly ISharedFileService _sharedFileService;
private readonly ICorpusService _corpusService;
private int _seed = 1234;
private Random _random;

protected TranslationEngineType EngineType { get; set; }
protected bool PretranslationEnabled { get; set; }

public PreprocessBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ CancellationToken cancellationToken
await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None))
{
await _smtModelFactory.DownloadBuiltEngineAsync(engineId, cancellationToken);
int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, @lock, cancellationToken);
int segmentPairsSize = await TrainOnNewSegmentPairs(engineId, cancellationToken);
await PlatformService.BuildCompletedAsync(
buildId,
trainSize: data.Item1 + segmentPairsSize,
Expand All @@ -43,50 +43,38 @@ await PlatformService.BuildCompletedAsync(
Logger.LogInformation("Build completed ({0}).", buildId);
}

private async Task<int> TrainOnNewSegmentPairs(
string engineId,
IDistributedReaderWriterLock @lock,
CancellationToken cancellationToken
)
private async Task<int> TrainOnNewSegmentPairs(string engineId, CancellationToken cancellationToken)
{
TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
if (engine is null)
throw new OperationCanceledException();

await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken))
{
cancellationToken.ThrowIfCancellationRequested();
IReadOnlyList<TrainSegmentPair> segmentPairs = await _trainSegmentPairs.GetAllAsync(
p => p.TranslationEngineRef == engine.Id,
CancellationToken.None
);
if (segmentPairs.Count == 0)
return segmentPairs.Count;
cancellationToken.ThrowIfCancellationRequested();
IReadOnlyList<TrainSegmentPair> segmentPairs = await _trainSegmentPairs.GetAllAsync(
p => p.TranslationEngineRef == engine.Id,
CancellationToken.None
);
if (segmentPairs.Count == 0)
return segmentPairs.Count;

var tokenizer = new LatinWordTokenizer();
var detokenizer = new LatinWordDetokenizer();
ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId);
var tokenizer = new LatinWordTokenizer();
var detokenizer = new LatinWordDetokenizer();
ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId);

using (
IInteractiveTranslationModel smtModel = _smtModelFactory.Create(
engineId,
tokenizer,
detokenizer,
truecaser
)
)
using (
IInteractiveTranslationModel smtModel = _smtModelFactory.Create(engineId, tokenizer, detokenizer, truecaser)
)
{
foreach (TrainSegmentPair segmentPair in segmentPairs)
{
foreach (TrainSegmentPair segmentPair in segmentPairs)
{
await smtModel.TrainSegmentAsync(
segmentPair.Source,
segmentPair.Target,
cancellationToken: CancellationToken.None
);
}
await smtModel.SaveAsync(CancellationToken.None);
await smtModel.TrainSegmentAsync(
segmentPair.Source,
segmentPair.Target,
cancellationToken: CancellationToken.None
);
}
return segmentPairs.Count;
await smtModel.SaveAsync(CancellationToken.None);
}
return segmentPairs.Count;
}
}
11 changes: 2 additions & 9 deletions src/SIL.Machine.AspNetCore/Services/SmtTransferTrainBuildJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
public class SmtTransferTrainBuildJob(
IPlatformService platformService,
IRepository<TranslationEngine> engines,
IOptionsMonitor<SmtTransferEngineOptions> engineOptions,
IDistributedReaderWriterLockFactory lockFactory,
IBuildJobService buildJobService,
ILogger<SmtTransferTrainBuildJob> logger,
Expand All @@ -15,7 +14,6 @@ ISmtModelFactory smtModelFactory
private readonly ISharedFileService _sharedFileService = sharedFileService;
private readonly ITruecaserFactory _truecaserFactory = truecaserFactory;
private readonly ISmtModelFactory _smtModelFactory = smtModelFactory;
private readonly IOptionsMonitor<SmtTransferEngineOptions> _engineOptions = engineOptions;

protected override async Task DoWorkAsync(
string engineId,
Expand Down Expand Up @@ -49,17 +47,12 @@ CancellationToken cancellationToken
await smtModelTrainer.TrainAsync(progress, cancellationToken);
await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken);

string modelDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId);
Directory.CreateDirectory(modelDir);

cancellationToken.ThrowIfCancellationRequested();

await smtModelTrainer.SaveAsync(CancellationToken.None);
await truecaseTrainer.SaveAsync(CancellationToken.None);

// save model to S3 bucket
Stream modelDst = await _sharedFileService.OpenWriteAsync($"models/{engineId}.zip");
ZipFile.CreateFromDirectory(modelDir, modelDst);
modelDst.Close();
await _smtModelFactory.UploadBuiltEngineAsync(engineId, cancellationToken);

cancellationToken.ThrowIfCancellationRequested();

Expand Down
8 changes: 8 additions & 0 deletions src/SIL.Machine.AspNetCore/Services/ThotSmtModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ public async Task DownloadBuiltEngineAsync(string engineId, CancellationToken ca
await _sharedFileService.DeleteAsync(sharedFilePath);
}

/// <summary>
/// Zips the local directory of the given engine and writes the archive to the shared file
/// service at <c>models/{engineId}.zip</c>.
/// </summary>
/// <param name="engineId">Engine identifier; names both the local directory and the shared archive.</param>
/// <param name="cancellationToken">Token passed to the shared file service when opening the write stream.</param>
/// <exception cref="DirectoryNotFoundException">The local engine directory does not exist.</exception>
public async Task UploadBuiltEngineAsync(string engineId, CancellationToken cancellationToken)
{
    string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId);
    // Fail before touching shared storage so that no write stream is opened when there is nothing to upload.
    if (!Directory.Exists(engineDir))
        throw new DirectoryNotFoundException($"The engine directory '{engineDir}' does not exist.");

    string sharedFilePath = $"models/{engineId}.zip";
    // Dispose asynchronously so that closing the stream does not block the calling thread inside an async method.
    await using Stream sharedStream = await _sharedFileService.OpenWriteAsync(sharedFilePath, cancellationToken);
    ZipFile.CreateFromDirectory(engineDir, sharedStream);
}

public void InitNew(string engineId)
{
string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ await env.Service.StartBuildAsync(
]
);
await env.WaitForBuildToFinishAsync();
await env
.SmtBatchTrainer.Received()
.TrainAsync(Arg.Any<IProgress<ProgressStatus>>(), Arg.Any<CancellationToken>());
await env
.TruecaserTrainer.Received()
.TrainAsync(Arg.Any<IProgress<ProgressStatus>>(), Arg.Any<CancellationToken>());
await env.SmtBatchTrainer.Received().SaveAsync(Arg.Any<CancellationToken>());
await env.TruecaserTrainer.Received().SaveAsync(Arg.Any<CancellationToken>());
engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Null);
Assert.That(engine.BuildRevision, Is.EqualTo(2));
Expand Down Expand Up @@ -119,11 +127,10 @@ await env.SmtBatchTrainer.TrainAsync(
})
);
await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty<Corpus>());
await env.WaitForBuildToStartAsync();
await env.WaitForTrainingToStartAsync();
TranslationEngine engine = env.Engines.Get(EngineId1);
Assert.That(engine.CurrentBuild, Is.Not.Null);
Assert.That(engine.CurrentBuild.JobState, Is.EqualTo(BuildJobState.Active));
await Task.Delay(200);
env.StopServer();
await env.WaitForBuildToRestartAsync();
engine = env.Engines.Get(EngineId1);
Expand Down Expand Up @@ -270,17 +277,7 @@ public TestEnvironment()
);
Truecaser = Substitute.For<ITruecaser>();
TruecaserTrainer = Substitute.For<ITrainer>();
EngineOptions = Substitute.For<IOptionsMonitor<SmtTransferEngineOptions>>();
DirectoryInfo tempDir = Directory.CreateTempSubdirectory();
EngineOptions.CurrentValue.Returns(new SmtTransferEngineOptions() { EnginesDir = tempDir.FullName });

Task SaveTrueCaserAsync()
{
using (File.Create(Path.Combine(tempDir.FullName, EngineId1, "unigram-casing-model.txt"))) { }
return Task.CompletedTask;
}

TruecaserTrainer.When(x => x.SaveAsync()).Do(_ => SaveTrueCaserAsync());
SmtModelFactory = CreateSmtModelFactory();
TransferEngineFactory = CreateTransferEngineFactory();
_truecaserFactory = CreateTruecaserFactory();
Expand Down Expand Up @@ -356,7 +353,6 @@ [new SmtTransferClearMLBuildJobFactory(SharedFileService, Engines)],
public IClearMLQueueService ClearMLMonitorService { get; }

public ISharedFileService SharedFileService { get; }
public IOptionsMonitor<SmtTransferEngineOptions> EngineOptions { get; }

public IBuildJobService BuildJobService { get; }
public Func<Task> TrainJobFunc { get; set; }
Expand All @@ -370,8 +366,8 @@ public async Task CommitAsync(TimeSpan inactiveTimeout)

// Stops the test server by disposing the state service and the job server.
// NOTE(review): StateService.Dispose() appears both before and after _jobServer.Dispose().
// This looks like the removed and the added line of the change rendered together; presumably
// the intended order is job server first, then state service. Confirm and drop the extra call.
public void StopServer()
{
StateService.Dispose();
_jobServer.Dispose();
StateService.Dispose();
}

public void StartServer()
Expand Down Expand Up @@ -580,6 +576,13 @@ public Task WaitForBuildToStartAsync()
return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Active);
}

/// <summary>
/// Returns the task produced by <c>WaitForBuildState</c> for the condition that the engine's
/// current build is in the <c>Active</c> job state and the <c>Train</c> stage.
/// </summary>
public Task WaitForTrainingToStartAsync() =>
    WaitForBuildState(engine =>
        engine.CurrentBuild!.JobState == BuildJobState.Active
        && engine.CurrentBuild!.Stage == BuildStage.Train
    );

public Task WaitForBuildToRestartAsync()
{
return WaitForBuildState(e => e.CurrentBuild!.JobState is BuildJobState.Pending);
Expand Down Expand Up @@ -614,44 +617,22 @@ private class EnvActivator(TestEnvironment env) : JobActivator
{
private readonly TestEnvironment _env = env;

/// <summary>
/// Subclass of <see cref="SmtTransferPreprocessBuildJob"/> that forwards all constructor
/// arguments to the base class and then sets <c>TrainJobRunnerType</c> to
/// <see cref="BuildJobRunnerType.Hangfire"/>.
/// </summary>
public class SmtTransferPreprocessBuildJobTest : SmtTransferPreprocessBuildJob
{
    public SmtTransferPreprocessBuildJobTest(
        IPlatformService platformService,
        IRepository<TranslationEngine> engines,
        IDistributedReaderWriterLockFactory lockFactory,
        ILogger<SmtTransferPreprocessBuildJob> logger,
        IBuildJobService buildJobService,
        ISharedFileService sharedFileService,
        ICorpusService corpusService
    )
        : base(platformService, engines, lockFactory, logger, buildJobService, sharedFileService, corpusService)
    {
        // Override the runner type after the base constructor has run.
        TrainJobRunnerType = BuildJobRunnerType.Hangfire;
    }
}

public override object ActivateJob(Type jobType)
{
if (jobType == typeof(SmtTransferPreprocessBuildJob))
{
return new SmtTransferPreprocessBuildJobTest(
return new SmtTransferPreprocessBuildJob(
_env.PlatformService,
_env.Engines,
_env._lockFactory,
Substitute.For<ILogger<SmtTransferPreprocessBuildJobTest>>(),
Substitute.For<ILogger<SmtTransferPreprocessBuildJob>>(),
_env.BuildJobService,
_env.SharedFileService,
Substitute.For<ICorpusService>()
);
)
{
TrainJobRunnerType = BuildJobRunnerType.Hangfire
};
}
if (jobType == typeof(SmtTransferPostprocessBuildJob))
{
Expand All @@ -672,7 +653,6 @@ public override object ActivateJob(Type jobType)
return new SmtTransferTrainBuildJob(
_env.PlatformService,
_env.Engines,
_env.EngineOptions,
_env._lockFactory,
_env.BuildJobService,
Substitute.For<ILogger<SmtTransferTrainBuildJob>>(),
Expand Down

0 comments on commit b6ac585

Please sign in to comment.