Skip to content

Commit

Permalink
#32 Fixes for NMT cancellation
Browse files Browse the repository at this point in the history
* Fix "get task by name" from ClearML
* Use same sub for cancellation as from SMT - and refactor
* Fix tests
  • Loading branch information
johnml1135 committed Jul 12, 2023
1 parent b8e50ca commit dc2dade
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,30 +35,36 @@ public async Task RunAsync(
string engineId,
string buildId,
IReadOnlyList<Corpus> corpora,
CancellationToken cancellationToken
CancellationToken externalCancellationToken
)
{
string? clearMLProjectId = await _clearMLService.GetProjectIdAsync(engineId, cancellationToken);
string? clearMLProjectId = await _clearMLService.GetProjectIdAsync(engineId, externalCancellationToken);
if (clearMLProjectId is null)
return;

try
{
var combinedCancellationToken = new SubscribeForCancellation(_engines).GetCombinedCancellationToken(
engineId,
buildId,
externalCancellationToken
);

TranslationEngine? engine = await _engines.GetAsync(
e => e.EngineId == engineId && e.BuildId == buildId,
cancellationToken: cancellationToken
cancellationToken: combinedCancellationToken
);
if (engine is null || engine.IsCanceled)
throw new OperationCanceledException();

int corpusSize;
if (engine.BuildState is BuildState.Pending)
corpusSize = await WriteDataFilesAsync(buildId, corpora, cancellationToken);
corpusSize = await WriteDataFilesAsync(buildId, corpora, combinedCancellationToken);
else
corpusSize = GetCorpusSize(corpora);

string clearMLTaskId;
ClearMLTask? clearMLTask = await _clearMLService.GetTaskAsync(buildId, clearMLProjectId, cancellationToken);
ClearMLTask? clearMLTask = await _clearMLService.GetTaskByNameAsync(buildId, combinedCancellationToken);
if (clearMLTask is null)
{
clearMLTaskId = await _clearMLService.CreateTaskAsync(
Expand All @@ -68,7 +74,7 @@ CancellationToken cancellationToken
engine.SourceLanguage,
engine.TargetLanguage,
_sharedFileService.GetBaseUri().ToString(),
cancellationToken
combinedCancellationToken
);
await _clearMLService.EnqueueTaskAsync(clearMLTaskId, CancellationToken.None);
}
Expand All @@ -80,9 +86,9 @@ CancellationToken cancellationToken
int lastIteration = 0;
while (true)
{
cancellationToken.ThrowIfCancellationRequested();
combinedCancellationToken.ThrowIfCancellationRequested();

clearMLTask = await _clearMLService.GetTaskAsync(clearMLTaskId, cancellationToken);
clearMLTask = await _clearMLService.GetTaskByIdAsync(clearMLTaskId, combinedCancellationToken);
if (clearMLTask is null)
throw new InvalidOperationException("The ClearML task does not exist.");

Expand All @@ -98,7 +104,7 @@ or ClearMLTaskStatus.Completed
engine = await _engines.UpdateAsync(
e => e.EngineId == engineId && e.BuildId == buildId && !e.IsCanceled,
u => u.Set(e => e.BuildState, BuildState.Active),
cancellationToken: cancellationToken
cancellationToken: combinedCancellationToken
);
if (engine is null)
throw new OperationCanceledException();
Expand Down Expand Up @@ -129,11 +135,11 @@ await _engines.UpdateAsync(
}
if (clearMLTask.Status is ClearMLTaskStatus.Completed)
break;
await Task.Delay(_options.CurrentValue.BuildPollingTimeout, cancellationToken);
await Task.Delay(_options.CurrentValue.BuildPollingTimeout, combinedCancellationToken);
}

// The ClearML task has successfully completed, so insert the generated pretranslations into the database.
await InsertPretranslationsAsync(engineId, buildId, cancellationToken);
await InsertPretranslationsAsync(engineId, buildId, combinedCancellationToken);

IReadOnlyDictionary<string, double> metrics = await _clearMLService.GetTaskMetricsAsync(
clearMLTaskId,
Expand Down Expand Up @@ -174,11 +180,7 @@ await _platformService.BuildCompletedAsync(
if (engine is null || engine.IsCanceled)
{
// This is an actual cancellation triggered by an API call.
ClearMLTask? task = await _clearMLService.GetTaskAsync(
buildId,
clearMLProjectId,
CancellationToken.None
);
ClearMLTask? task = await _clearMLService.GetTaskByNameAsync(buildId, CancellationToken.None);
if (task is not null)
await _clearMLService.StopTaskAsync(task.Id, CancellationToken.None);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ protected override Expression<Func<ClearMLNmtEngineBuildJob, Task>> GetJobExpres
IReadOnlyList<Corpus> corpora
)
{
// Token "None" is used here because hangfire injects the proper cancellation token
return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None);
}
}
16 changes: 4 additions & 12 deletions src/Machine/src/SIL.Machine.AspNetCore/Services/ClearMLService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,12 @@ public async Task<bool> StopTaskAsync(string id, CancellationToken cancellationT
return updated == 1;
}

public Task<ClearMLTask?> GetTaskAsync(string name, string projectId, CancellationToken cancellationToken = default)
public Task<ClearMLTask?> GetTaskByNameAsync(string name, CancellationToken cancellationToken = default)
{
return GetTaskAsync(
new JsonObject
{
["id"] = new JsonArray(),
["name"] = name,
["project"] = new JsonArray(projectId)
},
cancellationToken
);
return GetTaskAsync(new JsonObject { ["name"] = name }, cancellationToken);
}

public Task<ClearMLTask?> GetTaskAsync(string id, CancellationToken cancellationToken = default)
/// <summary>
/// Looks up a single ClearML task by its id.
/// </summary>
/// <param name="id">Value placed under the "id" key of the query body.</param>
/// <param name="cancellationToken">Token forwarded unchanged to the underlying query.</param>
/// <returns>Whatever the <see cref="JsonObject"/>-based <c>GetTaskAsync</c> overload returns for this filter.</returns>
public Task<ClearMLTask?> GetTaskByIdAsync(string id, CancellationToken cancellationToken = default)
{
    // Build the filter separately so the delegation below reads as a plain pass-through.
    var idFilter = new JsonObject { ["id"] = id };
    return GetTaskAsync(idFilter, cancellationToken);
}
Expand Down Expand Up @@ -206,7 +198,7 @@ public async Task<IReadOnlyDictionary<string, double>> GetTaskMetricsAsync(
"status_reason",
"active_duration"
);
JsonObject? result = await CallAsync("tasks", "get_by_id_ex", body, cancellationToken);
JsonObject? result = await CallAsync("tasks", "get_all_ex", body, cancellationToken);
var tasks = (JsonArray?)result?["data"]?["tasks"];
if (tasks is null || tasks.Count == 0)
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Task<string> CreateTaskAsync(
Task<bool> EnqueueTaskAsync(string id, CancellationToken cancellationToken = default);
Task<bool> DequeueTaskAsync(string id, CancellationToken cancellationToken = default);
Task<bool> StopTaskAsync(string id, CancellationToken cancellationToken = default);
Task<ClearMLTask?> GetTaskAsync(string name, string projectId, CancellationToken cancellationToken = default);
Task<ClearMLTask?> GetTaskAsync(string id, CancellationToken cancellationToken = default);
Task<ClearMLTask?> GetTaskByNameAsync(string name, CancellationToken cancellationToken = default);
Task<ClearMLTask?> GetTaskByIdAsync(string id, CancellationToken cancellationToken = default);
Task<IReadOnlyDictionary<string, double>> GetTaskMetricsAsync(
string id,
CancellationToken cancellationToken = default
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,11 @@ CancellationToken externalCancellationToken
ITrainer? truecaseTrainer = null;
try
{
CancellationTokenSource cts = new();
SubscribeForCancellationAsync(cts, engineId, buildId);
CancellationTokenSource combinedCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(
externalCancellationToken,
cts.Token
var combinedCancellationToken = new SubscribeForCancellation(_engines).GetCombinedCancellationToken(
engineId,
buildId,
externalCancellationToken
);
var combinedCancellationToken = combinedCancellationSource.Token;

var stopwatch = new Stopwatch();
TranslationEngine? engine;
Expand Down Expand Up @@ -226,27 +224,4 @@ await _engines.UpdateAsync(
truecaseTrainer?.Dispose();
}
}

private async void SubscribeForCancellationAsync(CancellationTokenSource cts, string engineId, string buildId)
{
var cancellationToken = cts.Token;
ISubscription<TranslationEngine> sub = await _engines.SubscribeAsync(
e => e.EngineId == engineId && e.BuildId == buildId
);
if (sub.Change.Entity is null)
return;
while (true)
{
await sub.WaitForChangeAsync(TimeSpan.FromSeconds(10), cancellationToken);
TranslationEngine? engine = sub.Change.Entity;
if (engine is null || engine.IsCanceled)
{
cts.Cancel();
return;
}
if (cancellationToken.IsCancellationRequested)
return;
Thread.Sleep(500);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ protected override Expression<Func<SmtTransferEngineBuildJob, Task>> GetJobExpre
IReadOnlyList<Corpus> corpora
)
{
// Token "None" is used here because hangfire injects the proper cancellation token
return r => r.RunAsync(engineId, buildId, corpora, CancellationToken.None);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
namespace SIL.Machine.AspNetCore.Services;

/// <summary>
/// Produces a cancellation token for a build job that is signalled either when the caller-supplied token is
/// cancelled or when the watched <see cref="TranslationEngine"/> record disappears or has
/// <c>IsCanceled</c> set.
/// </summary>
public class SubscribeForCancellation
{
    // Maximum time a single wait on the subscription may take before the record is re-checked.
    private static readonly TimeSpan ChangeTimeout = TimeSpan.FromSeconds(10);

    // Pause between two consecutive checks of the engine record.
    private static readonly TimeSpan PollInterval = TimeSpan.FromMilliseconds(500);

    private readonly IRepository<TranslationEngine> _engines;

    /// <summary>Creates a watcher factory that reads engine state from <paramref name="engines"/>.</summary>
    public SubscribeForCancellation(IRepository<TranslationEngine> engines)
    {
        _engines = engines;
    }

    /// <summary>
    /// Starts a background watcher on the engine identified by <paramref name="engineId"/> and
    /// <paramref name="buildId"/> and returns a token linked to both that watcher and
    /// <paramref name="externalCancellationToken"/>.
    /// </summary>
    /// <param name="engineId">Engine id used to select the record to watch.</param>
    /// <param name="buildId">Build id used to select the record to watch.</param>
    /// <param name="externalCancellationToken">Token supplied by the caller of the job.</param>
    /// <returns>A token that is cancelled when either source is cancelled.</returns>
    public CancellationToken GetCombinedCancellationToken(
        string engineId,
        string buildId,
        CancellationToken externalCancellationToken
    )
    {
        // NOTE(review): neither token source is disposed here because the returned token outlives this call
        // and this class has no hook that tells it when the job has finished.
        CancellationTokenSource cts = new();
        CancellationTokenSource combinedCancellationSource = CancellationTokenSource.CreateLinkedTokenSource(
            externalCancellationToken,
            cts.Token
        );
        CancellationToken combinedCancellationToken = combinedCancellationSource.Token;

        // Fire-and-forget on purpose: the watcher returns a Task (not async void) so that a failure inside it
        // surfaces as an unobserved task exception instead of being rethrown on the thread pool.
        _ = SubscribeForCancellationAsync(cts, combinedCancellationToken, engineId, buildId);
        return combinedCancellationToken;
    }

    /// <summary>
    /// Polls the engine subscription and cancels <paramref name="cts"/> once the engine is gone or canceled.
    /// The loop ends as soon as <paramref name="stopToken"/> is signalled, which covers both cancellation by
    /// the caller and cancellation triggered by this watcher itself.
    /// </summary>
    private async Task SubscribeForCancellationAsync(
        CancellationTokenSource cts,
        CancellationToken stopToken,
        string engineId,
        string buildId
    )
    {
        try
        {
            // NOTE(review): if ISubscription<T> is disposable, this subscription should be disposed when the
            // watcher exits - confirm against the repository API.
            ISubscription<TranslationEngine> sub = await _engines.SubscribeAsync(
                e => e.EngineId == engineId && e.BuildId == buildId
            );
            if (sub.Change.Entity is null)
                return;
            while (true)
            {
                await sub.WaitForChangeAsync(ChangeTimeout, stopToken);
                TranslationEngine? engine = sub.Change.Entity;
                if (engine is null || engine.IsCanceled)
                {
                    cts.Cancel();
                    return;
                }
                if (stopToken.IsCancellationRequested)
                    return;
                // Asynchronous pause: does not block a thread-pool thread the way Thread.Sleep does.
                await Task.Delay(PollInterval, stopToken);
            }
        }
        catch (OperationCanceledException)
        {
            // The combined token was cancelled while waiting; nothing is left to watch.
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public async Task CancelBuildAsync()
};
bool first = true;
env.ClearMLService
.GetTaskAsync(Arg.Any<string>(), "project1", Arg.Any<CancellationToken>())
.GetTaskByNameAsync(Arg.Any<string>(), Arg.Any<CancellationToken>())
.Returns(x =>
{
if (first)
Expand All @@ -37,7 +37,7 @@ public async Task CancelBuildAsync()
return Task.FromResult<ClearMLTask?>(task);
});
env.ClearMLService
.GetTaskAsync("task1", Arg.Any<CancellationToken>())
.GetTaskByIdAsync("task1", Arg.Any<CancellationToken>())
.Returns(Task.FromResult<ClearMLTask?>(task));
await env.Service.StartBuildAsync("engine1", "build1", Array.Empty<Corpus>());
await env.WaitForBuildToStartAsync();
Expand Down

0 comments on commit dc2dade

Please sign in to comment.