Skip to content

Commit

Permalink
Leverage "Retry-After" usage to LUIS API calls
Browse files Browse the repository at this point in the history
If a "Retry-After" HTTP response header is returned by LUIS, this change ensures that header is respected.

Fixes microsoft#333
  • Loading branch information
rozele committed Nov 23, 2020
1 parent 4916f57 commit 2d3848c
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 162 deletions.
2 changes: 1 addition & 1 deletion src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public interface ILuisTrainClient : IDisposable
/// <param name="appId">LUIS app ID.</param>
/// <param name="versionId">LUIS version ID.</param>
/// <param name="cancellationToken">Cancellation token.</param>
Task<IList<ModelTrainingInfo>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken);
Task<OperationResponse<IList<ModelTrainingInfo>>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken);

/// <summary>
/// Imports the LUIS app version.
Expand Down
51 changes: 19 additions & 32 deletions src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,6 @@ public void Dispose()
this.LuisClient.Dispose();
}

private static bool IsTransientStatusCode(HttpStatusCode statusCode)
{
return statusCode == HttpStatusCode.TooManyRequests
|| (statusCode >= HttpStatusCode.InternalServerError
&& statusCode != HttpStatusCode.HttpVersionNotSupported
&& statusCode != HttpStatusCode.NotImplemented);
}

private LuisApp CreateLuisApp(IEnumerable<ILabeledUtterance> utterances)
{
var luisApp = this.CreateLuisAppTemplate();
Expand Down Expand Up @@ -216,36 +208,31 @@ private async Task PollTrainingStatusAsync(CancellationToken cancellationToken)
{
while (true)
{
try
{
var trainingStatus = await this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken).ConfigureAwait(false);
var inProgress = trainingStatus
.Select(modelInfo => modelInfo.Details.Status)
.Any(status => status == "InProgress" || status == "Queued");
var trainingStatus = await Retry.With(cancellationToken).OnTransientErrorResponseAsync(() =>
this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken))
.ConfigureAwait(false);

if (!inProgress)
{
if (trainingStatus.Any(modelInfo => modelInfo.Details.Status == "Fail"))
{
var failureReasons = trainingStatus
.Where(modelInfo => modelInfo.Details.Status == "Fail")
.Select(modelInfo => $"- {modelInfo.Details.FailureReason}");
var inProgress = trainingStatus.Value
.Select(modelInfo => modelInfo.Details.Status)
.Any(status => status == "InProgress" || status == "Queued");

throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}");
}
if (!inProgress)
{
if (trainingStatus.Value.Any(modelInfo => modelInfo.Details.Status == "Fail"))
{
var failureReasons = trainingStatus.Value
.Where(modelInfo => modelInfo.Details.Status == "Fail")
.Select(modelInfo => $"- {modelInfo.Details.FailureReason}");

break;
throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}");
}

Logger.LogTrace($"Training jobs not complete. Polling again.");
await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false);
}
catch (ErrorResponseException ex)
when (IsTransientStatusCode(ex.Response.StatusCode))
{
Logger.LogTrace("Received HTTP 429 result from LUIS. Retrying.");
await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false);
break;
}

Logger.LogTrace($"Training jobs not complete. Polling again.");
var delay = Retry.GetRetryAfterDelay(trainingStatus.RetryAfter, TrainStatusDelay);
await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ public Task DeleteVersionAsync(string appId, string versionId, CancellationToken
return this.AuthoringClient.Versions.DeleteAsync(Guid.Parse(appId), versionId, cancellationToken);
}

public Task<IList<ModelTrainingInfo>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken)
public async Task<OperationResponse<IList<ModelTrainingInfo>>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken)
{
return this.AuthoringClient.Train.GetStatusAsync(Guid.Parse(appId), versionId, cancellationToken);
var operationResponse = await this.AuthoringClient.Train.GetStatusWithHttpMessagesAsync(Guid.Parse(appId), versionId, cancellationToken: cancellationToken).ConfigureAwait(false);
return OperationResponse.Create(operationResponse.Body, operationResponse.Response);
}

public Task ImportVersionAsync(string appId, string versionId, LuisApp luisApp, CancellationToken cancellationToken)
Expand Down
3 changes: 3 additions & 0 deletions src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,8 @@
<Compile Include="$(MSBuildThisFileDirectory)ILuisConfiguration.cs" />
<Compile Include="$(MSBuildThisFileDirectory)TestLuisConfiguration.cs" />
<Compile Include="$(MSBuildThisFileDirectory)JSONEntityWithRole.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Retry.cs" />
<Compile Include="$(MSBuildThisFileDirectory)OperationResponse.Generic.cs" />
<Compile Include="$(MSBuildThisFileDirectory)OperationResponse.cs" />
</ItemGroup>
</Project>
28 changes: 28 additions & 0 deletions src/NLU.DevOps.Luis.Shared/OperationResponse.Generic.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
/// <summary>
/// Information about the batch test evaluation operation status.
/// </summary>
/// <typeparam name="T">Type of response value.</typeparam>
public class OperationResponse<T>
{
internal OperationResponse(T value, string retryAfter)
{
this.Value = value;
this.RetryAfter = retryAfter;
}

/// <summary>
/// Gets the response value.
/// </summary>
public T Value { get; }

/// <summary>
/// Gets the HTTP 'Retry-After' header.
/// </summary>
public string RetryAfter { get; }
}
}
27 changes: 27 additions & 0 deletions src/NLU.DevOps.Luis.Shared/OperationResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
using System.Linq;
using System.Net.Http;

/// <summary>
/// Factory methods for <see cref="OperationResponse{T}"/>.
/// </summary>
public static class OperationResponse
{
/// <summary>
/// Creates an instance of <see cref="OperationResponse{T}"/>.
/// </summary>
/// <typeparam name="T">Type of response value.</typeparam>
/// <param name="value">Response value.</param>
/// <param name="response">HTTP response.</param>
/// <returns>Instance of <see cref="OperationResponse{T}"/>.</returns>
public static OperationResponse<T> Create<T>(T value, HttpResponseMessage response = default)
{
var retryAfter = response?.Headers?.GetValues(Retry.RetryAfterHeader).FirstOrDefault();
return new OperationResponse<T>(value, retryAfter);
}
}
}
121 changes: 121 additions & 0 deletions src/NLU.DevOps.Luis.Shared/Retry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

namespace NLU.DevOps.Luis
{
using System;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.CognitiveServices.Language.LUIS.Authoring.Models;
#if LUIS_V2
using ErrorException = Microsoft.Azure.CognitiveServices.Language.LUIS.Runtime.Models.APIErrorException;
#else
using Microsoft.Azure.CognitiveServices.Language.LUIS.Runtime.Models;
#endif

internal static class Retry
{
public const string RetryAfterHeader = "Retry-After";

private static readonly Regex RetryAfterSecondsRegex = new Regex(@"^\d+$");

private static TimeSpan DefaultTransientDelay { get; } = TimeSpan.FromMilliseconds(100);

public static TimeSpan GetRetryAfterDelay(string retryAfter, TimeSpan? defaultDelay = default)
{
if (retryAfter == null)
{
return defaultDelay ?? DefaultTransientDelay;
}

if (RetryAfterSecondsRegex.IsMatch(retryAfter))
{
return TimeSpan.FromSeconds(int.Parse(retryAfter, CultureInfo.InvariantCulture));
}

return DateTimeOffset.Parse(retryAfter, CultureInfo.InvariantCulture) - DateTimeOffset.Now;
}

public static CancellationTokenHolder With(CancellationToken cancellationToken)
{
return new CancellationTokenHolder(cancellationToken);
}

private static async Task<TResult> OnTransientExceptionAsync<TResult, TException>(
Func<Task<TResult>> func,
Func<TException, HttpStatusCode> statusCodeSelector,
Func<TException, string> retryAfterDelaySelector = default,
int retryCount = int.MaxValue,
CancellationToken cancellationToken = default)
where TException : Exception
{
var count = 0;
while (count++ < retryCount)
{
cancellationToken.ThrowIfCancellationRequested();

try
{
return await func().ConfigureAwait(false);
}
catch (TException ex)
when (count < retryCount && IsTransientStatusCode(statusCodeSelector(ex)))
{
var delay = GetRetryAfterDelay(retryAfterDelaySelector?.Invoke(ex));
await Task.Delay(delay, cancellationToken).ConfigureAwait(false);
}
}

throw new InvalidOperationException("Exception will be rethrown before reaching this point.");
}

private static bool IsTransientStatusCode(HttpStatusCode statusCode)
{
return statusCode == HttpStatusCode.TooManyRequests
|| (statusCode >= HttpStatusCode.InternalServerError
&& statusCode != HttpStatusCode.HttpVersionNotSupported
&& statusCode != HttpStatusCode.NotImplemented);
}

public class CancellationTokenHolder
{
public CancellationTokenHolder(CancellationToken cancellationToken)
{
this.CancellationToken = cancellationToken;
}

private CancellationToken CancellationToken { get; }

public Task<T> OnTransientErrorAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(ErrorException ex) => ex.Response.StatusCode,
(ErrorException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(),
cancellationToken: this.CancellationToken);
}

public Task<T> OnTransientErrorResponseAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(ErrorResponseException ex) => ex.Response.StatusCode,
(ErrorResponseException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(),
cancellationToken: this.CancellationToken);
}

public Task<T> OnTransientWebExceptionAsync<T>(Func<Task<T>> func)
{
return OnTransientExceptionAsync(
func,
(WebException ex) => (ex.Response as HttpWebResponse)?.StatusCode ?? default,
(WebException ex) => (ex.Response as HttpWebResponse)?.Headers?[RetryAfterHeader],
cancellationToken: this.CancellationToken);
}
}
}
}
31 changes: 17 additions & 14 deletions src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,13 +215,14 @@ public static async Task TrainingStatusDelayBetweenPolling()
It.Is<string>(appId => appId == builder.AppId),
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.Returns(() => Task.FromResult<IList<ModelTrainingInfo>>(new[]
{
new ModelTrainingInfo
.Returns(() => Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(new[]
{
Details = new ModelTrainingDetails { Status = statusArray[count++] }
}
}))
new ModelTrainingInfo
{
Details = new ModelTrainingDetails { Status = statusArray[count++] }
}
})))
.Callback(() => timestamps[count - 1] = DateTimeOffset.Now);

using (var luis = builder.Build())
Expand Down Expand Up @@ -251,13 +252,14 @@ public static void TrainingFailedThrowsInvalidOperation()
It.Is<string>(appId => appId == builder.AppId),
It.IsAny<string>(),
It.IsAny<CancellationToken>()))
.Returns(() => Task.FromResult<IList<ModelTrainingInfo>>(new[]
{
new ModelTrainingInfo
.Returns(() => Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(new[]
{
Details = new ModelTrainingDetails { Status = "Fail", FailureReason = failureReason }
}
}));
new ModelTrainingInfo
{
Details = new ModelTrainingDetails { Status = "Fail", FailureReason = failureReason }
}
})));

using (var luis = builder.Build())
{
Expand Down Expand Up @@ -377,8 +379,9 @@ private class LuisNLUTrainClientBuilder
public LuisNLUTrainClient Build()
{
this.MockLuisTrainClient.SetReturnsDefault(
Task.FromResult<IList<ModelTrainingInfo>>(
Array.Empty<ModelTrainingInfo>()));
Task.FromResult(
OperationResponse.Create<IList<ModelTrainingInfo>>(
Array.Empty<ModelTrainingInfo>())));

var luisConfiguration = new LuisConfiguration(new ConfigurationBuilder()
.AddInMemoryCollection(new Dictionary<string, string>
Expand Down
Loading

0 comments on commit 2d3848c

Please sign in to comment.