Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Permalink
always invoke profile service from default claims service; add loggin…
Browse files Browse the repository at this point in the history
…g to profile impls
  • Loading branch information
brockallen committed Jan 15, 2017
1 parent 361e984 commit f5b5d2e
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ public ProfileDataRequestContext()
/// <param name="requestedClaimTypes">The requested claim types.</param>
public ProfileDataRequestContext(ClaimsPrincipal subject, Client client, string caller, IEnumerable<string> requestedClaimTypes)
{
if (requestedClaimTypes.IsNullOrEmpty()) throw new ArgumentException("No claim types requested", nameof(requestedClaimTypes));
if (subject == null) throw new ArgumentNullException(nameof(subject));
if (client == null) throw new ArgumentNullException(nameof(client));
if (caller == null) throw new ArgumentNullException(nameof(caller));
if (requestedClaimTypes == null) throw new ArgumentNullException(nameof(requestedClaimTypes));

Subject = subject;
Client = client;
Expand Down
74 changes: 42 additions & 32 deletions src/IdentityServer4/Services/DefaultClaimsService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,31 +61,31 @@ public virtual async Task<IEnumerable<Claim>> GetIdentityTokenClaimsAsync(Claims
// fetch all identity claims that need to go into the id token
if (includeAllIdentityClaims || client.AlwaysIncludeUserClaimsInIdToken)
{
var additionalClaims = new List<string>();
var additionalClaimTypes = new List<string>();

foreach (var identityResource in resources.IdentityResources)
{
foreach (var userClaim in identityResource.UserClaims)
{
additionalClaims.Add(userClaim);
additionalClaimTypes.Add(userClaim);
}
}

if (additionalClaims.Count > 0)
{
var context = new ProfileDataRequestContext(
subject,
client,
IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken,
additionalClaims);
// filter so we don't ask for claim types that we will eventually filter out
additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList();

await Profile.GetProfileDataAsync(context);
var context = new ProfileDataRequestContext(
subject,
client,
IdentityServerConstants.ProfileDataCallers.ClaimsProviderIdentityToken,
additionalClaimTypes);

var claims = FilterProtocolClaims(context.IssuedClaims);
if (claims != null)
{
outputClaims.AddRange(claims);
}
await Profile.GetProfileDataAsync(context);

var claims = FilterProtocolClaims(context.IssuedClaims);
if (claims != null)
{
outputClaims.AddRange(claims);
}
}

Expand Down Expand Up @@ -155,15 +155,15 @@ public virtual async Task<IEnumerable<Claim>> GetAccessTokenClaimsAsync(ClaimsPr
outputClaims.AddRange(GetOptionalClaims(subject));

// fetch all resource claims that need to go into the access token
var additionalClaims = new List<string>();
var additionalClaimTypes = new List<string>();
foreach (var api in resources.ApiResources)
{
// add claims configured on api resource
if (api.UserClaims != null)
{
foreach (var claim in api.UserClaims)
{
additionalClaims.Add(claim);
additionalClaimTypes.Add(claim);
}
}

Expand All @@ -174,27 +174,27 @@ public virtual async Task<IEnumerable<Claim>> GetAccessTokenClaimsAsync(ClaimsPr
{
foreach (var claim in scope.UserClaims)
{
additionalClaims.Add(claim);
additionalClaimTypes.Add(claim);
}
}
}
}

if (additionalClaims.Count > 0)
{
var context = new ProfileDataRequestContext(
subject,
client,
IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken,
additionalClaims.Distinct());
// filter so we don't ask for claim types that we will eventually filter out
additionalClaimTypes = FilterRequestedClaimTypes(additionalClaimTypes).ToList();

await Profile.GetProfileDataAsync(context);
var context = new ProfileDataRequestContext(
subject,
client,
IdentityServerConstants.ProfileDataCallers.ClaimsProviderAccessToken,
additionalClaimTypes.Distinct());

var claims = FilterProtocolClaims(context.IssuedClaims);
if (claims != null)
{
outputClaims.AddRange(claims);
}
await Profile.GetProfileDataAsync(context);

var claims = FilterProtocolClaims(context.IssuedClaims);
if (claims != null)
{
outputClaims.AddRange(claims);
}
}

Expand Down Expand Up @@ -246,9 +246,19 @@ protected virtual IEnumerable<Claim> FilterProtocolClaims(IEnumerable<Claim> cla
if (claimsToFilter.Any())
{
var types = claimsToFilter.Select(x => x.Type);
_logger.LogInformation("Claim types from profile service that were filtered: {claimTypes}", types);
_logger.LogDebug("Claim types from profile service that were filtered: {claimTypes}", types);
}
return claims.Except(claimsToFilter);
}

/// <summary>
/// Filters out protocol claims like amr, nonce etc..
/// </summary>
/// <param name="claimTypes">The claim types.</param>
protected virtual IEnumerable<string> FilterRequestedClaimTypes(IEnumerable<string> claimTypes)
{
var claimTypesToFilter = claimTypes.Where(x => Constants.Filters.ClaimsServiceFilterClaimTypes.Contains(x));
return claimTypes.Except(claimTypesToFilter);
}
}
}
22 changes: 21 additions & 1 deletion src/IdentityServer4/Services/DefaultProfileService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

using System.Threading.Tasks;
using IdentityServer4.Models;
using IdentityServer4.Extensions;
using Microsoft.Extensions.Logging;
using System.Linq;

namespace IdentityServer4.Services
{
Expand All @@ -13,14 +16,31 @@ namespace IdentityServer4.Services
/// <seealso cref="IdentityServer4.Services.IProfileService" />
public class DefaultProfileService : IProfileService
{
private readonly ILogger<DefaultProfileService> _logger;

public DefaultProfileService(ILogger<DefaultProfileService> logger)
{
_logger = logger;
}

/// <summary>
/// This method is called whenever claims about the user are requested (e.g. during token creation or via the userinfo endpoint)
/// </summary>
/// <param name="context">The context.</param>
/// <returns></returns>
public Task GetProfileDataAsync(ProfileDataRequestContext context)
{
context.AddFilteredClaims(context.Subject.Claims);
_logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}",
context.Subject.GetSubjectId(),
context.Client.ClientName,
context.RequestedClaimTypes,
context.Caller);

if (context.RequestedClaimTypes.Any())
{
context.AddFilteredClaims(context.Subject.Claims);
}

return Task.FromResult(0);
}

Expand Down
20 changes: 16 additions & 4 deletions src/IdentityServer4/Test/TestUserProfileService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,36 @@
using IdentityServer4.Extensions;
using IdentityServer4.Models;
using IdentityServer4.Services;
using Microsoft.Extensions.Logging;
using System.Linq;
using System.Threading.Tasks;

namespace IdentityServer4.Test
{
public class TestUserProfileService : IProfileService
{
private readonly ILogger<TestUserProfileService> _logger;
private readonly TestUserStore _users;

public TestUserProfileService(TestUserStore users)
public TestUserProfileService(TestUserStore users, ILogger<TestUserProfileService> logger)
{
_users = users;
_logger = logger;
}

public Task GetProfileDataAsync(ProfileDataRequestContext context)
{
var user = _users.FindBySubjectId(context.Subject.GetSubjectId());

context.AddFilteredClaims(user.Claims);
_logger.LogDebug("Get profile called for {subject} from {client} with {claimTypes} because {caller}",
context.Subject.GetSubjectId(),
context.Client.ClientName,
context.RequestedClaimTypes,
context.Caller);

if (context.RequestedClaimTypes.Any())
{
var user = _users.FindBySubjectId(context.Subject.GetSubjectId());
context.AddFilteredClaims(user.Claims);
}

return Task.FromResult(0);
}
Expand Down

0 comments on commit f5b5d2e

Please sign in to comment.