Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leverage "Retry-After" usage to LUIS API calls #347

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public interface ILuisTrainClient : IDisposable
/// <param name="appId">LUIS app ID.</param>
/// <param name="versionId">LUIS version ID.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task<IList<ModelTrainingInfo>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken);
Task<OperationResponse<IList<ModelTrainingInfo>>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken);

/// <summary>
/// Imports the LUIS app version.
Expand Down
51 changes: 19 additions & 32 deletions src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,6 @@ public void Dispose()
this.LuisClient.Dispose();
}

private static bool IsTransientStatusCode(HttpStatusCode statusCode)
{
return statusCode == HttpStatusCode.TooManyRequests
|| (statusCode >= HttpStatusCode.InternalServerError
&& statusCode != HttpStatusCode.HttpVersionNotSupported
&& statusCode != HttpStatusCode.NotImplemented);
}

private LuisApp CreateLuisApp(IEnumerable<ILabeledUtterance> utterances)
{
var luisApp = this.CreateLuisAppTemplate();
Expand Down Expand Up @@ -216,36 +208,31 @@ private async Task PollTrainingStatusAsync(CancellationToken cancellationToken)
{
while (true)
{
try
{
var trainingStatus = await this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken).ConfigureAwait(false);
var inProgress = trainingStatus
.Select(modelInfo => modelInfo.Details.Status)
.Any(status => status == "InProgress" || status == "Queued");
var trainingStatus = await Retry.With(cancellationToken).OnTransientErrorResponseAsync(() =>
this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken))
.ConfigureAwait(false);

if (!inProgress)
{
if (trainingStatus.Any(modelInfo => modelInfo.Details.Status == "Fail"))
{
var failureReasons = trainingStatus
.Where(modelInfo => modelInfo.Details.Status == "Fail")
.Select(modelInfo => $"- {modelInfo.Details.FailureReason}");
var inProgress = trainingStatus.Value
.Select(modelInfo => modelInfo.Details.Status)
.Any(status => status == "InProgress" || status == "Queued");

throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}");
}
if (!inProgress)
{
if (trainingStatus.Value.Any(modelInfo => modelInfo.Details.Status == "Fail"))
{
var failureReasons = trainingStatus.Value
.Where(modelInfo => modelInfo.Details.Status == "Fail")
.Select(modelInfo => $"- {modelInfo.Details.FailureReason}");

break;
throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}");
}

Logger.LogTrace($"Training jobs not complete. Polling again.");
await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false);
}
catch (ErrorResponseException ex)
when (IsTransientStatusCode(ex.Response.StatusCode))
{
Logger.LogTrace("Received HTTP 429 result from LUIS. Retrying.");
await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false);
break;
}

Logger.LogTrace($"Training jobs not complete. Polling again.");
var delay = Retry.GetRetryAfterDelay(trainingStatus.RetryAfter, TrainStatusDelay);
await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ public Task DeleteVersionAsync(string appId, string versionId, CancellationToken
return this.AuthoringClient.Versions.DeleteAsync(Guid.Parse(appId), versionId, cancellationToken);
}

public Task<IList<ModelTrainingInfo>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken)
public async Task<OperationResponse<IList<ModelTrainingInfo>>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken)
{
return this.AuthoringClient.Train.GetStatusAsync(Guid.Parse(appId), versionId, cancellationToken);
var operationResponse = await this.AuthoringClient.Train.GetStatusWithHttpMessagesAsync(Guid.Parse(appId), versionId, cancellationToken: cancellationToken).ConfigureAwait(false);
return OperationResponse.Create(operationResponse.Body, operationResponse.Response);
}

public Task ImportVersionAsync(string appId, string versionId, LuisApp luisApp, CancellationToken cancellationToken)
Expand Down
3 changes: 3 additions & 0 deletions src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,8 @@
<Compile Include="$(MSBuildThisFileDirectory)ILuisConfiguration.cs" />
<Compile Include="$(MSBuildThisFileDirectory)TestLuisConfiguration.cs" />
<Compile Include="$(MSBuildThisFileDirectory)JSONEntityWithRole.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Retry.cs" />
<Compile Include="$(MSBuildThisFileDirectory)OperationResponse.Generic.cs" />
<Compile Include="$(MSBuildThisFileDirectory)OperationResponse.cs" />
</ItemGroup>
</Project>
28 changes: 28 additions & 0 deletions src/NLU.DevOps.Luis.Shared/OperationResponse.Generic.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
/// <summary>
/// Information about the batch test evaluation operation status.
/// </summary>
/// <typeparam name="T">Type of response value.</typeparam>
public class OperationResponse<T>
{
internal OperationResponse(T value, string retryAfter)
{
this.Value = value;
this.RetryAfter = retryAfter;
}

/// <summary>
/// Gets the response value.
/// </summary>
public T Value { get; }

/// <summary>
/// Gets the HTTP 'Retry-After' header.
/// </summary>
public string RetryAfter { get; }
}
}
27 changes: 27 additions & 0 deletions src/NLU.DevOps.Luis.Shared/OperationResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
using System.Linq;
using System.Net.Http;

/// <summary>
/// Factory methods for <see cref="OperationResponse{T}"/>.
/// </summary>
public static class OperationResponse
{
/// <summary>
/// Creates an instance of <see cref="OperationResponse{T}"/>.
/// </summary>
/// <typeparam name="T">Type of response value.</typeparam>
/// <param name="value">Response value.</param>
/// <param name="response">HTTP response.</param>
/// <returns>Instance of <see cref="OperationResponse{T}"/>.</returns>
public static OperationResponse<T> Create<T>(T value, HttpResponseMessage response = default)
{
var retryAfter = response?.Headers?.GetValues(Retry.RetryAfterHeader).FirstOrDefault();
return new OperationResponse<T>(value, retryAfter);
}
}
}
121 changes: 121 additions & 0 deletions src/NLU.DevOps.Luis.Shared/Retry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
using System;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.CognitiveServices.Language.LUIS.Authoring.Models;
#if LUIS_V2
using ErrorException = Microsoft.Azure.CognitiveServices.Language.LUIS.Runtime.Models.APIErrorException;
#else
using Microsoft.Azure.CognitiveServices.Language.LUIS.Runtime.Models;
#endif

internal static class Retry
{
public const string RetryAfterHeader = "Retry-After";

private static readonly Regex RetryAfterSecondsRegex = new Regex(@"^\d+$");

private static TimeSpan DefaultTransientDelay { get; } = TimeSpan.FromMilliseconds(100);

public static TimeSpan GetRetryAfterDelay(string retryAfter, TimeSpan? defaultDelay = default)
{
if (retryAfter == null)
{
return defaultDelay ?? DefaultTransientDelay;
}

if (RetryAfterSecondsRegex.IsMatch(retryAfter))
{
return TimeSpan.FromSeconds(int.Parse(retryAfter, CultureInfo.InvariantCulture));
}

return DateTimeOffset.Parse(retryAfter, CultureInfo.InvariantCulture) - DateTimeOffset.Now;
}

public static CancellationTokenHolder With(CancellationToken cancellationToken)
{
return new CancellationTokenHolder(cancellationToken);
}

private static async Task<TResult> OnTransientExceptionAsync<TResult, TException>(
Func<Task<TResult>> func,
Func<TException, HttpStatusCode> statusCodeSelector,
Func<TException, string> retryAfterDelaySelector = default,
int retryCount = int.MaxValue,
CancellationToken cancellationToken = default)
where TException : Exception
{
var count = 0;
while (count++ < retryCount)
{
cancellationToken.ThrowIfCancellationRequested();

try
{
return await func().ConfigureAwait(false);
}
catch (TException ex)
when (count < retryCount && IsTransientStatusCode(statusCodeSelector(ex)))
{
var delay = GetRetryAfterDelay(retryAfterDelaySelector?.Invoke(ex));
await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
}
}

throw new InvalidOperationException("Exception will be rethrown before reaching this point.");
}

private static bool IsTransientStatusCode(HttpStatusCode statusCode)
{
return statusCode == HttpStatusCode.TooManyRequests
|| (statusCode >= HttpStatusCode.InternalServerError
&& statusCode != HttpStatusCode.HttpVersionNotSupported
&& statusCode != HttpStatusCode.NotImplemented);
}

public class CancellationTokenHolder
{
public CancellationTokenHolder(CancellationToken cancellationToken)
{
this.CancellationToken = cancellationToken;
}

private CancellationToken CancellationToken { get; }

public Task<T> OnTransientErrorAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(ErrorException ex) => ex.Response.StatusCode,
(ErrorException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(),
cancellationToken: this.CancellationToken);
}

public Task<T> OnTransientErrorResponseAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(ErrorResponseException ex) => ex.Response.StatusCode,
(ErrorResponseException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(),
cancellationToken: this.CancellationToken);
}

public Task<T> OnTransientWebExceptionAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(WebException ex) => (ex.Response as HttpWebResponse)?.StatusCode ?? default,
(WebException ex) => (ex.Response as HttpWebResponse)?.Headers?[RetryAfterHeader],
cancellationToken: this.CancellationToken);
}
}
}
}
31 changes: 17 additions & 14 deletions src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,14 @@ public static async Task TrainingStatusDelayBetweenPolling()
It.Is<string>(appId => appId == builder.AppId),
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.Returns(() => Task.FromResult<IList<ModelTrainingInfo>>(new[]
{
new ModelTrainingInfo
.Returns(() => Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(new[]
{
Details = new ModelTrainingDetails { Status = statusArray[count++] }
}
}))
new ModelTrainingInfo
{
Details = new ModelTrainingDetails { Status = statusArray[count++] }
}
})))
.Callback(() => timestamps[count - 1] = DateTimeOffset.Now);

using (var luis = builder.Build())
Expand Down Expand Up @@ -251,13 +252,14 @@ public static void TrainingFailedThrowsInvalidOperation()
It.Is<string>(appId => appId == builder.AppId),
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.Returns(() => Task.FromResult<IList<ModelTrainingInfo>>(new[]
{
new ModelTrainingInfo
.Returns(() => Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(new[]
{
Details = new ModelTrainingDetails { Status = "Fail", FailureReason = failureReason }
}
}));
new ModelTrainingInfo
{
Details = new ModelTrainingDetails { Status = "Fail", FailureReason = failureReason }
}
})));

using (var luis = builder.Build())
{
Expand Down Expand Up @@ -377,8 +379,9 @@ private class LuisNLUTrainClientBuilder
public LuisNLUTrainClient Build()
{
this.MockLuisTrainClient.SetReturnsDefault(
Task.FromResult<IList<ModelTrainingInfo>>(
Array.Empty<ModelTrainingInfo>()));
Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(
Array.Empty<ModelTrainingInfo>())));

var luisConfiguration = new LuisConfiguration(new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
Expand Down
Loading