From 94a311564c72faf666172a33e8a41cc2ab8bf83c Mon Sep 17 00:00:00 2001 From: Brock Allen Date: Sun, 5 Jul 2020 10:16:17 -0400 Subject: [PATCH] validate filter values on db results (#4618) --- .../src/Stores/ClientStore.cs | 8 ++--- .../src/Stores/DeviceFlowStore.cs | 14 +++++--- .../src/Stores/PersistedGrantStore.cs | 33 ++++++++++--------- .../src/Stores/ResourceStore.cs | 15 +++++---- 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/src/EntityFramework.Storage/src/Stores/ClientStore.cs b/src/EntityFramework.Storage/src/Stores/ClientStore.cs index f35eced4a7..bcd0e92c88 100644 --- a/src/EntityFramework.Storage/src/Stores/ClientStore.cs +++ b/src/EntityFramework.Storage/src/Stores/ClientStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. @@ -52,10 +52,10 @@ public ClientStore(IConfigurationDbContext context, ILogger logger) public virtual async Task FindClientByIdAsync(string clientId) { IQueryable baseQuery = Context.Clients - .Where(x => x.ClientId == clientId) - .Take(1); + .Where(x => x.ClientId == clientId); - var client = await baseQuery.FirstOrDefaultAsync(); + var client = (await baseQuery.ToArrayAsync()) + .SingleOrDefault(x => x.ClientId == clientId); if (client == null) return null; await baseQuery.Include(x => x.AllowedCorsOrigins).SelectMany(c => c.AllowedCorsOrigins).LoadAsync(); diff --git a/src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs b/src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs index f69df6aa65..e4053009c1 100644 --- a/src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs +++ b/src/EntityFramework.Storage/src/Stores/DeviceFlowStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. @@ -74,7 +74,8 @@ public virtual async Task StoreDeviceAuthorizationAsync(string deviceCode, strin /// public virtual async Task FindByUserCodeAsync(string userCode) { - var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.UserCode == userCode); + var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.UserCode == userCode).ToArrayAsync()) + .SingleOrDefault(x => x.UserCode == userCode); var model = ToModel(deviceFlowCodes?.Data); Logger.LogDebug("{userCode} found in database: {userCodeFound}", userCode, model != null); @@ -89,7 +90,8 @@ public virtual async Task FindByUserCodeAsync(string userCode) /// public virtual async Task FindByDeviceCodeAsync(string deviceCode) { - var deviceFlowCodes = await Context.DeviceFlowCodes.AsNoTracking().FirstOrDefaultAsync(x => x.DeviceCode == deviceCode); + var deviceFlowCodes = (await Context.DeviceFlowCodes.AsNoTracking().Where(x => x.DeviceCode == deviceCode).ToArrayAsync()) + .SingleOrDefault(x => x.DeviceCode == deviceCode); var model = ToModel(deviceFlowCodes?.Data); Logger.LogDebug("{deviceCode} found in database: {deviceCodeFound}", deviceCode, model != null); @@ -105,7 +107,8 @@ public virtual async Task FindByDeviceCodeAsync(string deviceCode) /// public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data) { - var existing = await Context.DeviceFlowCodes.SingleOrDefaultAsync(x => x.UserCode == userCode); + var existing = (await Context.DeviceFlowCodes.Where(x => x.UserCode == userCode).ToArrayAsync()) + .SingleOrDefault(x => x.UserCode == userCode); if (existing == null) { Logger.LogError("{userCode} not found in database", userCode); @@ -135,7 +138,8 @@ public virtual async Task UpdateByUserCodeAsync(string userCode, DeviceCode data /// public virtual async Task RemoveByDeviceCodeAsync(string deviceCode) { - var deviceFlowCodes = await Context.DeviceFlowCodes.FirstOrDefaultAsync(x => x.DeviceCode == deviceCode); + var deviceFlowCodes = (await Context.DeviceFlowCodes.Where(x => x.DeviceCode == deviceCode).ToArrayAsync()) + .SingleOrDefault(x => x.DeviceCode == deviceCode); if(deviceFlowCodes != null) { diff --git a/src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs b/src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs index b5deca5825..4254ead07e 100644 --- a/src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs +++ b/src/EntityFramework.Storage/src/Stores/PersistedGrantStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. @@ -48,7 +48,8 @@ public PersistedGrantStore(IPersistedGrantDbContext context, ILogger public virtual async Task StoreAsync(PersistedGrant token) { - var existing = await Context.PersistedGrants.SingleOrDefaultAsync(x => x.Key == token.Key); + var existing = (await Context.PersistedGrants.Where(x => x.Key == token.Key).ToArrayAsync()) + .SingleOrDefault(x => x.Key == token.Key); if (existing == null) { Logger.LogDebug("{persistedGrantKey} not found in database", token.Key); @@ -80,7 +81,8 @@ public virtual async Task StoreAsync(PersistedGrant token) /// public virtual async Task GetAsync(string key) { - var persistedGrant = await Context.PersistedGrants.AsNoTracking().FirstOrDefaultAsync(x => x.Key == key); + var persistedGrant = (await Context.PersistedGrants.AsNoTracking().Where(x => x.Key == key).ToArrayAsync()) + .SingleOrDefault(x => x.Key == key); var model = persistedGrant?.ToModel(); Logger.LogDebug("{persistedGrantKey} found in database: {persistedGrantKeyFound}", key, model != null); @@ -95,10 +97,11 @@ public virtual async Task GetAsync(string key) /// public virtual async Task> GetAllAsync(string subjectId) { - var persistedGrants = await Context.PersistedGrants.Where(x => x.SubjectId == subjectId).AsNoTracking().ToListAsync(); + var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId).AsNoTracking().ToArrayAsync()) + .Where(x => x.SubjectId == subjectId).ToArray(); var model = persistedGrants.Select(x => x.ToModel()); - Logger.LogDebug("{persistedGrantCount} persisted grants found for {subjectId}", persistedGrants.Count, subjectId); + Logger.LogDebug("{persistedGrantCount} persisted grants found for {subjectId}", persistedGrants.Length, subjectId); return model; } @@ -110,7 +113,8 @@ public virtual async Task> GetAllAsync(string subjec /// public virtual async Task RemoveAsync(string key) { - var persistedGrant = await Context.PersistedGrants.FirstOrDefaultAsync(x => x.Key == key); + var persistedGrant = (await Context.PersistedGrants.Where(x => x.Key == key).ToArrayAsync()) + .SingleOrDefault(x => x.Key == key); if (persistedGrant!= null) { Logger.LogDebug("removing {persistedGrantKey} persisted grant from database", key); @@ -140,9 +144,10 @@ public virtual async Task RemoveAsync(string key) /// public virtual async Task RemoveAllAsync(string subjectId, string clientId) { - var persistedGrants = await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToListAsync(); + var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToArrayAsync()) + .Where(x => x.SubjectId == subjectId && x.ClientId == clientId).ToArray(); - Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}", persistedGrants.Count, subjectId, clientId); + Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}", persistedGrants.Length, subjectId, clientId); Context.PersistedGrants.RemoveRange(persistedGrants); @@ -152,7 +157,7 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId) } catch (DbUpdateConcurrencyException ex) { - Logger.LogInformation("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}: {error}", persistedGrants.Count, subjectId, clientId, ex.Message); + Logger.LogInformation("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}: {error}", persistedGrants.Length, subjectId, clientId, ex.Message); } } @@ -165,12 +170,10 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId) /// public virtual async Task RemoveAllAsync(string subjectId, string clientId, string type) { - var persistedGrants = await Context.PersistedGrants.Where(x => - x.SubjectId == subjectId && - x.ClientId == clientId && - x.Type == type).ToListAsync(); + var persistedGrants = (await Context.PersistedGrants.Where(x => x.SubjectId == subjectId && x.ClientId == clientId && x.Type == type).ToArrayAsync()) + .Where(x => x.SubjectId == subjectId && x.ClientId == clientId && x.Type == type).ToArray(); - Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}", persistedGrants.Count, subjectId, clientId, type); + Logger.LogDebug("removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}", persistedGrants.Length, subjectId, clientId, type); Context.PersistedGrants.RemoveRange(persistedGrants); @@ -180,7 +183,7 @@ public virtual async Task RemoveAllAsync(string subjectId, string clientId, stri } catch (DbUpdateConcurrencyException ex) { - Logger.LogInformation("exception removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}: {error}", persistedGrants.Count, subjectId, clientId, type, ex.Message); + Logger.LogInformation("exception removing {persistedGrantCount} persisted grants from database for subject {subjectId}, clientId {clientId}, grantType {persistedGrantType}: {error}", persistedGrants.Length, subjectId, clientId, type, ex.Message); } } } diff --git a/src/EntityFramework.Storage/src/Stores/ResourceStore.cs b/src/EntityFramework.Storage/src/Stores/ResourceStore.cs index 1ac4c6ceb1..b32ac62a8e 100644 --- a/src/EntityFramework.Storage/src/Stores/ResourceStore.cs +++ b/src/EntityFramework.Storage/src/Stores/ResourceStore.cs @@ -1,4 +1,4 @@ -// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. +// Copyright (c) Brock Allen & Dominick Baier. All rights reserved. // Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. @@ -54,7 +54,7 @@ public virtual async Task FindApiResourceAsync(string name) from apiResource in Context.ApiResources where apiResource.Name == name select apiResource; - + var apis = query .Include(x => x.Secrets) .Include(x => x.Scopes) @@ -63,7 +63,8 @@ from apiResource in Context.ApiResources .Include(x => x.Properties) .AsNoTracking(); - var api = await apis.FirstOrDefaultAsync(); + var api = (await apis.ToArrayAsync()) + .SingleOrDefault(x => x.Name == name); if (api != null) { @@ -88,7 +89,7 @@ public virtual async Task> FindApiResourcesByScopeAsync var query = from api in Context.ApiResources - where api.Scopes.Where(x=>names.Contains(x.Name)).Any() + where api.Scopes.Any(x => names.Contains(x.Name)) select api; var apis = query @@ -99,7 +100,8 @@ where api.Scopes.Where(x=>names.Contains(x.Name)).Any() .Include(x => x.Properties) .AsNoTracking(); - var results = await apis.ToArrayAsync(); + var results = (await apis.ToArrayAsync()) + .Where(api => api.Scopes.Any(x => names.Contains(x.Name))); var models = results.Select(x => x.ToModel()).ToArray(); Logger.LogDebug("Found {scopes} API scopes in database", models.SelectMany(x => x.Scopes).Select(x => x.Name)); @@ -126,7 +128,8 @@ where scopes.Contains(identityResource.Name) .Include(x => x.Properties) .AsNoTracking(); - var results = await resources.ToArrayAsync(); + var results = (await resources.ToArrayAsync()) + .Where(x => scopes.Contains(x.Name)); Logger.LogDebug("Found {scopes} identity scopes in database", results.Select(x => x.Name));