From 89c2786c9a3bb17cd77957987c042ef161ec9235 Mon Sep 17 00:00:00 2001
From: Damien Daspit
Date: Fri, 12 Jan 2024 18:28:38 -0500
Subject: [PATCH] Passed shared file folder to ClearML NMT job

---
 .../Services/NmtClearMLBuildJobFactory.cs     |  6 +-
 .../NmtClearMLBuildJobFactoryTests.cs         | 94 +++++++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)
 create mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs

diff --git a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs
index 33e147044..22da69721 100644
--- a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs
+++ b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs
@@ -37,6 +37,9 @@ public async Task<string> CreateJobScriptAsync(
         if (engine is null)
             throw new InvalidOperationException("The engine does not exist.");
 
+        Uri sharedFileUri = _sharedFileService.GetBaseUri();
+        string baseUri = sharedFileUri.GetComponents(UriComponents.SchemeAndServer, UriFormat.Unescaped);
+        string folder = sharedFileUri.GetComponents(UriComponents.Path, UriFormat.Unescaped);
         return "from machine.jobs.build_nmt_engine import run\n"
             + "args = {\n"
             + $"    'model_type': '{_options.CurrentValue.ModelType}',\n"
@@ -44,7 +47,8 @@ public async Task<string> CreateJobScriptAsync(
             + $"    'build_id': '{buildId}',\n"
             + $"    'src_lang': '{_languageTagService.ConvertToFlores200Code(engine.SourceLanguage)}',\n"
             + $"    'trg_lang': '{_languageTagService.ConvertToFlores200Code(engine.TargetLanguage)}',\n"
-            + $"    'shared_file_uri': '{_sharedFileService.GetBaseUri()}',\n"
+            + $"    'shared_file_uri': '{baseUri}',\n"
+            + $"    'shared_file_folder': '{folder}',\n"
             + (buildOptions is not null ? $"    'build_options': '''{buildOptions}''',\n" : "")
             + $"    'clearml': True\n"
             + "}\n"
diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs
new file mode 100644
index 000000000..34b10e5da
--- /dev/null
+++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtClearMLBuildJobFactoryTests.cs
@@ -0,0 +1,94 @@
+namespace SIL.Machine.AspNetCore.Services;
+
+[TestFixture]
+public class NmtClearMLBuildJobFactoryTests
+{
+    [Test]
+    public async Task CreateJobScriptAsync_BuildOptions()
+    {
+        var env = new TestEnvironment();
+        string script = await env.BuildJobFactory.CreateJobScriptAsync(
+            "engine1",
+            "build1",
+            NmtBuildStages.Train,
+            buildOptions: "{ \"max_steps\": \"10\" }"
+        );
+        Assert.That(
+            script,
+            Is.EqualTo(
+                @"from machine.jobs.build_nmt_engine import run
+args = {
+    'model_type': 'test_model',
+    'engine_id': 'engine1',
+    'build_id': 'build1',
+    'src_lang': 'spa_Latn',
+    'trg_lang': 'eng_Latn',
+    'shared_file_uri': 's3://bucket',
+    'shared_file_folder': 'folder1/folder2',
+    'build_options': '''{ ""max_steps"": ""10"" }''',
+    'clearml': True
+}
+run(args)
+".ReplaceLineEndings("\n")
+            )
+        );
+    }
+
+    [Test]
+    public async Task CreateJobScriptAsync_NoBuildOptions()
+    {
+        var env = new TestEnvironment();
+        string script = await env.BuildJobFactory.CreateJobScriptAsync("engine1", "build1", NmtBuildStages.Train);
+        Assert.That(
+            script,
+            Is.EqualTo(
+                @"from machine.jobs.build_nmt_engine import run
+args = {
+    'model_type': 'test_model',
+    'engine_id': 'engine1',
+    'build_id': 'build1',
+    'src_lang': 'spa_Latn',
+    'trg_lang': 'eng_Latn',
+    'shared_file_uri': 's3://bucket',
+    'shared_file_folder': 'folder1/folder2',
+    'clearml': True
+}
+run(args)
+".ReplaceLineEndings("\n")
+            )
+        );
+    }
+
+    private class TestEnvironment
+    {
+        public ISharedFileService SharedFileService { get; }
+        public MemoryRepository<TranslationEngine> Engines { get; }
+        public IOptionsMonitor<ClearMLOptions> Options { get; }
+        public ILanguageTagService LanguageTagService { get; }
+        public NmtClearMLBuildJobFactory BuildJobFactory { get; }
+
+        public TestEnvironment()
+        {
+            Engines = new MemoryRepository<TranslationEngine>();
+            Engines.Add(
+                new TranslationEngine
+                {
+                    Id = "engine1",
+                    EngineId = "engine1",
+                    SourceLanguage = "es",
+                    TargetLanguage = "en",
+                    BuildRevision = 1,
+                    CurrentBuild = new Build { BuildId = "build1", JobState = BuildJobState.Pending }
+                }
+            );
+            Options = Substitute.For<IOptionsMonitor<ClearMLOptions>>();
+            Options.CurrentValue.Returns(new ClearMLOptions { ModelType = "test_model" });
+            SharedFileService = Substitute.For<ISharedFileService>();
+            SharedFileService.GetBaseUri().Returns(new Uri("s3://bucket/folder1/folder2"));
+            LanguageTagService = Substitute.For<ILanguageTagService>();
+            LanguageTagService.ConvertToFlores200Code("es").Returns("spa_Latn");
+            LanguageTagService.ConvertToFlores200Code("en").Returns("eng_Latn");
+            BuildJobFactory = new NmtClearMLBuildJobFactory(SharedFileService, LanguageTagService, Engines, Options);
+        }
+    }
+}