Skip to content
This repository has been archived by the owner on Nov 1, 2023. It is now read-only.

Improved Async Support for building identity #97

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions WebApiThrottle.WebApiDemo/Helpers/CustomThrottlingHandler.cs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Web;

namespace WebApiThrottle.WebApiDemo.Helpers
{
public class CustomThrottlingHandler : ThrottlingHandler
{
protected override RequestIdentity SetIdentity(System.Net.Http.HttpRequestMessage request)
protected override Task<RequestIdentity> SetIdentityAsync(System.Net.Http.HttpRequestMessage request)
{
return new RequestIdentity()
return Task.FromResult(new RequestIdentity()
{
ClientKey = request.Headers.Contains("Authorization-Key") ? request.Headers.GetValues("Authorization-Key").First() : "anon",
ClientIp = base.GetClientIp(request).ToString(),
Endpoint = request.RequestUri.AbsolutePath.ToLowerInvariant()
};
});
}
}
}
15 changes: 11 additions & 4 deletions WebApiThrottle/ThrottlingFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public ThrottlePolicy Policy
/// </summary>
public HttpStatusCode QuotaExceededResponseCode { get; set; }

public override void OnActionExecuting(HttpActionContext actionContext)
public override async Task OnActionExecutingAsync(HttpActionContext actionContext, System.Threading.CancellationToken cancellationToken)
{
EnableThrottlingAttribute attrPolicy = null;
var applyThrottling = ApplyThrottling(actionContext, out attrPolicy);
Expand All @@ -145,7 +145,7 @@ public override void OnActionExecuting(HttpActionContext actionContext)
core.Repository = Repository;
core.Policy = Policy;

var identity = SetIdentity(actionContext.Request);
var identity = await SetIdentityAsync(actionContext.Request);

if (!core.IsWhitelisted(identity))
{
Expand Down Expand Up @@ -219,17 +219,24 @@ public override void OnActionExecuting(HttpActionContext actionContext)
}
}

base.OnActionExecuting(actionContext);
await base.OnActionExecutingAsync(actionContext, cancellationToken);
}


[Obsolete("This method is deprecated, use SetIdentityAsync instead")]
protected virtual RequestIdentity SetIdentity(HttpRequestMessage request)
{
throw new NotImplementedException("This method is deprecated, use SetIdentityAsync instead");
}

protected virtual Task<RequestIdentity> SetIdentityAsync(HttpRequestMessage request)
{
var entry = new RequestIdentity();
entry.ClientIp = core.GetClientIp(request).ToString();
entry.Endpoint = request.RequestUri.AbsolutePath.ToLowerInvariant();
entry.ClientKey = request.Headers.Contains("Authorization-Token") ? request.Headers.GetValues("Authorization-Token").First() : "anon";

return entry;
return Task.FromResult(entry);
}

protected virtual string ComputeThrottleKey(RequestIdentity requestIdentity, RateLimitPeriod period)
Expand Down
20 changes: 13 additions & 7 deletions WebApiThrottle/ThrottlingHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public ThrottlePolicy Policy
/// </summary>
public HttpStatusCode QuotaExceededResponseCode { get; set; }

protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
// get policy from repo
if (policyRepository != null)
Expand All @@ -139,17 +139,17 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques

if (policy == null || (!policy.IpThrottling && !policy.ClientThrottling && !policy.EndpointThrottling))
{
return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}

core.Repository = Repository;
core.Policy = policy;

var identity = SetIdentity(request);
var identity = await SetIdentityAsync(request);

if (core.IsWhitelisted(identity))
{
return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}

TimeSpan timeSpan = TimeSpan.FromSeconds(1);
Expand Down Expand Up @@ -204,7 +204,7 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
: string.Format(message, rateLimit, rateLimitPeriod);

// break execution
return QuotaExceededResponse(
return await QuotaExceededResponse(
request,
content,
QuotaExceededResponseCode,
Expand All @@ -214,15 +214,21 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
}

// no throttling required
return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}

protected IPAddress GetClientIp(HttpRequestMessage request)
{
return core.GetClientIp(request);
}

[Obsolete("This method is deprecated, use SetIdentityAsync instead")]
protected virtual RequestIdentity SetIdentity(HttpRequestMessage request)
{
throw new NotImplementedException("This method is deprecated, use SetIdentityAsync instead");
}

protected virtual Task<RequestIdentity> SetIdentityAsync(HttpRequestMessage request)
{
var entry = new RequestIdentity();
entry.ClientIp = core.GetClientIp(request).ToString();
Expand All @@ -231,7 +237,7 @@ protected virtual RequestIdentity SetIdentity(HttpRequestMessage request)
? request.Headers.GetValues("Authorization-Token").First()
: "anon";

return entry;
return Task.FromResult(entry);
}

protected virtual string ComputeThrottleKey(RequestIdentity requestIdentity, RateLimitPeriod period)
Expand Down
10 changes: 8 additions & 2 deletions WebApiThrottle/ThrottlingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public override async Task Invoke(IOwinContext context)
core.Repository = Repository;
core.Policy = policy;

var identity = SetIdentity(request);
var identity = await SetIdentityAsync(request);

if (core.IsWhitelisted(identity))
{
Expand Down Expand Up @@ -213,7 +213,13 @@ public override async Task Invoke(IOwinContext context)
await Next.Invoke(context);
}

[Obsolete("This method is deprecated, use SetIdentityAsync instead")]
protected virtual RequestIdentity SetIdentity(IOwinRequest request)
{
throw new NotImplementedException("This method is deprecated, use SetIdentityAsync instead");
}

protected virtual Task<RequestIdentity> SetIdentityAsync(IOwinRequest request)
{
var entry = new RequestIdentity();
entry.ClientIp = request.RemoteIpAddress;
Expand All @@ -222,7 +228,7 @@ protected virtual RequestIdentity SetIdentity(IOwinRequest request)
? request.Headers.GetValues("Authorization-Token").First()
: "anon";

return entry;
return Task.FromResult(entry);
}

protected virtual string ComputeThrottleKey(RequestIdentity requestIdentity, RateLimitPeriod period)
Expand Down