Skip to content
This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

Commit

Permalink
validate filter values on db results (#4618)
Browse files Browse the repository at this point in the history
  • Loading branch information
brockallen authored Jul 5, 2020
1 parent 3d370dc commit 94a3115
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 30 deletions.
8 changes: 4 additions & 4 deletions src/EntityFramework.Storage/src/Stores/ClientStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -52,10 +52,10 @@ public ClientStore(IConfigurationDbContext context, ILogger<ClientStore> logger)
public virtual async Task<Client> FindClientByIdAsync(string clientId)
{
IQueryable<Entities.Client> baseQuery = Context.Clients
.Where(x => x.ClientId == clientId)
.Take(1);
.Where(x => x.ClientId == clientId);

var client = await baseQuery.FirstOrDefaultAsync();
var client = (await baseQuery.ToArrayAsync())
.SingleOrDefault(x => x.ClientId == clientId);
if (client == null) return null;

await baseQuery.Include(x => x.AllowedCorsOrigins).SelectMany(c => c.AllowedCorsOrigins).LoadAsync();
Expand Down
14 changes: 9 additions & 5 deletions src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -74,7 +74,8 @@ public virtual async Task StoreDeviceAuthorizationAsync(string deviceCode, strin
/// <returns></returns>
public virtual async Task<DeviceCode> FindByUserCodeAsync(string userCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.UserCode == userCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.UserCode == userCode).ToArrayAsync())
.SingleOrDefault(x => x.UserCode == userCode);
var model = ToModel(deviceFlowCodes?.Data);

Logger.LogDebug("{userCode} found in database: {userCodeFound}", userCode, model != null);
Expand All @@ -89,7 +90,8 @@ public virtual async Task<DeviceCode> FindByUserCodeAsync(string userCode)
/// <returns></returns>
public virtual async Task<DeviceCode> FindByDeviceCodeAsync(string deviceCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.DeviceCode == deviceCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.DeviceCode == deviceCode).ToArrayAsync())
.SingleOrDefault(x => x.DeviceCode == deviceCode);
var model = ToModel(deviceFlowCodes?.Data);

Logger.LogDebug("{deviceCode} found in database: {deviceCodeFound}", deviceCode, model != null);
Expand All @@ -105,7 +107,8 @@ public virtual async Task<DeviceCode> FindByDeviceCodeAsync(string deviceCode)
/// <returns></returns>
public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data)
{
var existing = await Context.DeviceFlowCodes.SingleOrDefaultAsync(x => x.UserCode == userCode);
var existing = (await Context.DeviceFlowCodes.Where(x => x.UserCode == userCode).ToArrayAsync())
.SingleOrDefault(x => x.UserCode == userCode);
if (existing == null)
{
Logger.LogError("{userCode} not found in database", userCode);
Expand Down Expand Up @@ -135,7 +138,8 @@ public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data
/// <returns></returns>
public virtual async Task RemoveByDeviceCodeAsync(string deviceCode)
{
var deviceFlowCodes = await Context.DeviceFlowCodes.FirstOrDefaultAsync(x => x.DeviceCode == deviceCode);
var deviceFlowCodes = (await Context.DeviceFlowCodes.Where(x => x.DeviceCode == deviceCode).ToArrayAsync())
.SingleOrDefault(x => x.DeviceCode == deviceCode);

if(deviceFlowCodes != null)
{
Expand Down
33 changes: 18 additions & 15 deletions src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -48,7 +48,8 @@ public PersistedGrantStore(IPersistedGrantDbContext context, ILogger<PersistedGr
/// <returns></returns>
public virtual async Task StoreAsync(PersistedGrant token)
{
var existing = await Context.PersistedGrants.SingleOrDefaultAsync(x => x.Key == token.Key);
var existing = (await Context.PersistedGrants.Where(x => x.Key == token.Key).ToArrayAsync())
.SingleOrDefault(x => x.Key == token.Key);
if (existing == null)
{
Logger.LogDebug("{persistedGrantKey} not found in database", token.Key);
Expand Down Expand Up @@ -80,7 +81,8 @@ public virtual async Task StoreAsync(PersistedGrant token)
/// <returns></returns>
public virtual async Task<PersistedGrant> GetAsync(string key)
{
var persistedGrant = await Context.PersistedGrants.AsNoTracking().FirstOrDefaultAsync(x => x.Key == key);
var persistedGrant = (await Context.PersistedGrants.AsNoTracking().Where(x => x.Key == key).ToArrayAsync())
.SingleOrDefault(x => x.Key == key);
var model = persistedGrant?.ToModel();

Logger.LogDebug("{persistedGrantKey} found in database: {persistedGrantKeyFound}", key, model != null);
Expand All @@ -95,10 +97,11 @@ public virtual async Task<PersistedGrant> GetAsync(string key)
/// <returns></returns>
public virtual async Task<IEnumerable<PersistedGrant>> GetAllAsync(string subjectId)
{
var persistedGrants = await Context.PersistedGrants.Where(x => x.SubjectId == subjectId).AsNoTracking().ToListAsync();
var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId).AsNoTracking().ToArrayAsync())
.Where(x => x.SubjectId == subjectId).ToArray();
var model = persistedGrants.Select(x => x.ToModel());

Logger.LogDebug("{persistedGrantCount} persisted grants found for {subjectId}", persistedGrants.Count, subjectId);
Logger.LogDebug("{persistedGrantCount} persisted grants found for {subjectId}", persistedGrants.Length, subjectId);

return model;
}
Expand All @@ -110,7 +113,8 @@ public virtual async Task<IEnumerable<PersistedGrant>> GetAllAsync(string subjec
/// <returns></returns>
public virtual async Task RemoveAsync(string key)
{
var persistedGrant = await Context.PersistedGrants.FirstOrDefaultAsync(x => x.Key == key);
var persistedGrant = (await Context.PersistedGrants.Where(x => x.Key == key).ToArrayAsync())
.SingleOrDefault(x => x.Key == key);
if (persistedGrant!= null)
{
Logger.LogDebug("removing {persistedGrantKey} persisted grant from database", key);
Expand Down Expand Up @@ -140,9 +144,10 @@ public virtual async Task RemoveAsync(string key)
/// <returns></returns>
public virtual async Task RemoveAllAsync(string subjectId, string clientId)
{
var persistedGrants = await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToListAsync();
var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToArrayAsync())
.Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToArray();

Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}", persistedGrants.Count, subjectId, clientId);
Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}", persistedGrants.Length, subjectId, clientId);

Context.PersistedGrants.RemoveRange(persistedGrants);

Expand All @@ -152,7 +157,7 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId)
}
catch (DbUpdateConcurrencyException ex)
{
Logger.LogInformation("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}: {error}", persistedGrants.Count, subjectId, clientId, ex.Message);
Logger.LogInformation("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}: {error}", persistedGrants.Length, subjectId, clientId, ex.Message);
}
}

Expand All @@ -165,12 +170,10 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId)
/// <returns></returns>
public virtual async Task RemoveAllAsync(string subjectId, string clientId, string type)
{
var persistedGrants = await Context.PersistedGrants.Where(x =>
x.SubjectId == subjectId &&
x.ClientId == clientId &&
x.Type == type).ToListAsync();
var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId && x.Type == type).ToArrayAsync())
.Where(x => x.SubjectId == subjectId && x.ClientId == clientId && x.Type == type).ToArray();

Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}", persistedGrants.Count, subjectId, clientId, type);
Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}", persistedGrants.Length, subjectId, clientId, type);

Context.PersistedGrants.RemoveRange(persistedGrants);

Expand All @@ -180,7 +183,7 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId, stri
}
catch (DbUpdateConcurrencyException ex)
{
Logger.LogInformation("exception removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}: {error}", persistedGrants.Count, subjectId, clientId, type, ex.Message);
Logger.LogInformation("exception removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}: {error}", persistedGrants.Length, subjectId, clientId, type, ex.Message);
}
}
}
Expand Down
15 changes: 9 additions & 6 deletions src/EntityFramework.Storage/src/Stores/ResourceStore.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Copyright (c) Brock Allen & Dominick Baier. All rights reserved.
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.


Expand Down Expand Up @@ -54,7 +54,7 @@ public virtual async Task<ApiResource> FindApiResourceAsync(string name)
from apiResource in Context.ApiResources
where apiResource.Name == name
select apiResource;

var apis = query
.Include(x => x.Secrets)
.Include(x => x.Scopes)
Expand All @@ -63,7 +63,8 @@ from apiResource in Context.ApiResources
.Include(x => x.Properties)
.AsNoTracking();

var api = await apis.FirstOrDefaultAsync();
var api = (await apis.ToArrayAsync())
.SingleOrDefault(x => x.Name == name);

if (api != null)
{
Expand All @@ -88,7 +89,7 @@ public virtual async Task<IEnumerable<ApiResource>> FindApiResourcesByScopeAsync

var query =
from api in Context.ApiResources
where api.Scopes.Where(x=>names.Contains(x.Name)).Any()
where api.Scopes.Any(x => names.Contains(x.Name))
select api;

var apis = query
Expand All @@ -99,7 +100,8 @@ where api.Scopes.Where(x=>names.Contains(x.Name)).Any()
.Include(x => x.Properties)
.AsNoTracking();

var results = await apis.ToArrayAsync();
var results = (await apis.ToArrayAsync())
.Where(api => api.Scopes.Any(x => names.Contains(x.Name)));
var models = results.Select(x => x.ToModel()).ToArray();

Logger.LogDebug("Found {scopes} API scopes in database", models.SelectMany(x => x.Scopes).Select(x => x.Name));
Expand All @@ -126,7 +128,8 @@ where scopes.Contains(identityResource.Name)
.Include(x => x.Properties)
.AsNoTracking();

var results = await resources.ToArrayAsync();
var results = (await resources.ToArrayAsync())
.Where(x => scopes.Contains(x.Name));

Logger.LogDebug("Found {scopes} identity scopes in database", results.Select(x => x.Name));

Expand Down

0 comments on commit 94a3115

Please sign in to comment.