Skip to content

Commit

Permalink
IAuthTokenProvider: support for token-based authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
satano committed Dec 19, 2019
1 parent 80b1e9c commit baf8dc6
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 27 deletions.
14 changes: 14 additions & 0 deletions src/IAuthTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
namespace Kros.KORM
{
/// <summary>
/// Support for token-based authentication for SQL Server.
/// </summary>
public interface IAuthTokenProvider
{
/// <summary>
/// Returns authentication token, or <see langword="null" /> value, if token can not be obtained.
/// </summary>
/// <returns>Authentication token.</returns>
string GetToken();
}
}
23 changes: 13 additions & 10 deletions src/Query/Providers/QueryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -164,9 +165,7 @@ public QueryProvider(
}

private void InitSqlExpressionVisitor(ISqlExpressionVisitorFactory sqlGeneratorFactory)
{
_sqlExpressionVisitor = new Lazy<ISqlExpressionVisitor>(() => sqlGeneratorFactory.CreateVisitor(Connection));
}
=> _sqlExpressionVisitor = new Lazy<ISqlExpressionVisitor>(() => sqlGeneratorFactory.CreateVisitor(GetConnection()));

private TransactionHelper TransactionHelperFactory()
{
Expand Down Expand Up @@ -211,7 +210,7 @@ public void SetParameterDbType(DbParameter parameter, string tableName, string c
private TableSchema LoadTableSchema(string tableName)
{
IDatabaseSchemaLoader schemaLoader = GetSchemaLoader();
TableSchema tableSchema = schemaLoader.LoadTableSchema(Connection, tableName);
TableSchema tableSchema = schemaLoader.LoadTableSchema(GetConnection(), tableName);
return tableSchema
?? throw new InvalidOperationException(string.Format(Resources.QueryProviderCouldNotGetTableSchema, tableName));
}
Expand Down Expand Up @@ -546,7 +545,7 @@ public ITransaction BeginTransaction(IsolationLevel isolationLevel) =>
/// <inheritdoc/>
public IIdGenerator CreateIdGenerator(string tableName, int batchSize)
{
var connection = (Connection as ICloneable).Clone() as DbConnection;
var connection = (GetConnection() as ICloneable).Clone() as DbConnection;
try
{
connection.Open();
Expand Down Expand Up @@ -640,6 +639,10 @@ public TResult Execute<TResult>(Expression expression)
[Obsolete("Use GetConnection() method.")]
protected DbConnection Connection => GetConnection();

/// <summary>
/// Returns (creates if needed) connection.
/// </summary>
/// <returns><see cref="DbConnection"/> instance.</returns>
protected virtual DbConnection GetConnection()
{
if (_connection == null)
Expand All @@ -650,11 +653,10 @@ protected virtual DbConnection GetConnection()
return _connection;
}

private Data.ConnectionHelper OpenConnection()
{
return new Data.ConnectionHelper(Connection);
}
private Data.ConnectionHelper OpenConnection() => new Data.ConnectionHelper(GetConnection());

[SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "Query is external or generated.")]
private DbCommandInfo CreateCommand<T>(IQuery<T> query)
{
DbCommand command = _transactionHelper.Value.CreateCommand();
Expand All @@ -677,6 +679,8 @@ private DbCommandInfo CreateCommand<T>(IQuery<T> query)
protected internal void SetQueryFilter<T>(IQuery<T> query, ISqlExpressionVisitor sqlVisitor)
=> (query as IQueryBaseInternal).ApplyQueryFilter(_databaseMapper, sqlVisitor);

[SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "Query is external or generated.")]
private DbCommand CreateCommand(string commandText, CommandParameterCollection parameters)
{
DbCommand command = _transactionHelper.Value.CreateCommand();
Expand All @@ -689,7 +693,6 @@ private DbCommand CreateCommand(string commandText, CommandParameterCollection p
AddCommandParameter(command, parameter);
}
}

return command;
}

Expand Down
48 changes: 42 additions & 6 deletions src/Query/Providers/SqlServerQueryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace Kros.KORM.Query
/// <seealso cref="Kros.KORM.Query.QueryProvider" />
public class SqlServerQueryProvider : QueryProvider
{
private readonly IAuthTokenProvider _tokenProvider;

/// <summary>
/// Initializes a new instance of the <see cref="SqlServerQueryProvider"/> class.
/// </summary>
Expand All @@ -25,14 +27,17 @@ public class SqlServerQueryProvider : QueryProvider
/// <param name="modelBuilder">The model builder.</param>
/// <param name="logger">The logger.</param>
/// <param name="databaseMapper">The Database mapper.</param>
/// <param name="tokenProvider">Provider to support token-based authentication.</param>
public SqlServerQueryProvider(
KormConnectionSettings connectionString,
ISqlExpressionVisitorFactory sqlGeneratorFactory,
IModelBuilder modelBuilder,
ILogger logger,
IDatabaseMapper databaseMapper)
IDatabaseMapper databaseMapper,
IAuthTokenProvider tokenProvider)
: base(connectionString, sqlGeneratorFactory, modelBuilder, logger, databaseMapper)
{
_tokenProvider = tokenProvider;
}

/// <summary>
Expand All @@ -43,21 +48,44 @@ public SqlServerQueryProvider(
/// <param name="modelBuilder">The model builder.</param>
/// <param name="logger">The logger.</param>
/// <param name="databaseMapper">The Database mapper.</param>
/// <param name="tokenProvider">Provider to support token-based authentication.</param>
public SqlServerQueryProvider(
DbConnection connection,
ISqlExpressionVisitorFactory sqlGeneratorFactory,
IModelBuilder modelBuilder,
ILogger logger,
IDatabaseMapper databaseMapper)
IDatabaseMapper databaseMapper,
IAuthTokenProvider tokenProvider)
: base(connection, sqlGeneratorFactory, modelBuilder, logger, databaseMapper)
{
_tokenProvider = tokenProvider;
}

/// <summary>
/// Returns <see cref="DbProviderFactory"/> for current provider.
/// </summary>
public override DbProviderFactory DbProviderFactory => SqlClientFactory.Instance;

/// <summary>
/// Returns (creates if needed) connection. If <see cref="IAuthTokenProvider"/> was setup in constructor,
/// it is used to set the <see cref="SqlConnection.AccessToken">AccessToken</see> on connection.
/// </summary>
/// <returns><see cref="DbConnection"/> instance.</returns>
protected override DbConnection GetConnection()
{
var connection = (SqlConnection)base.GetConnection();
SetAccessToken(connection);
return connection;
}

private void SetAccessToken(SqlConnection connection)
{
if (_tokenProvider != null)
{
connection.AccessToken = _tokenProvider.GetToken();
}
}

/// <summary>
/// Creates instance of <see cref="IBulkInsert" />.
/// </summary>
Expand All @@ -69,11 +97,11 @@ public override IBulkInsert CreateBulkInsert()
var transaction = GetCurrentTransaction();
if (IsExternalConnection || transaction != null)
{
return new SqlServerBulkInsert(Connection as SqlConnection, transaction as SqlTransaction);
return new SqlServerBulkInsert(GetConnection() as SqlConnection, transaction as SqlTransaction);
}
else
{
return new SqlServerBulkInsert(ConnectionString);
return new SqlServerBulkInsert(CreateConnection());
}
}

Expand All @@ -89,14 +117,22 @@ public override IBulkUpdate CreateBulkUpdate()

if (IsExternalConnection || transaction != null)
{
return new SqlServerBulkUpdate(Connection as SqlConnection, transaction as SqlTransaction);
return new SqlServerBulkUpdate(GetConnection() as SqlConnection, transaction as SqlTransaction);
}
else
{
return new SqlServerBulkUpdate(ConnectionString);
return new SqlServerBulkUpdate(CreateConnection());
}
}

private SqlConnection CreateConnection()
{
var connection = (SqlConnection)DbProviderFactory.CreateConnection();
connection.ConnectionString = ConnectionString;
SetAccessToken(connection);
return connection;
}

/// <summary>
/// Returns instance of <see cref="SqlServerSchemaLoader"/>.
/// </summary>
Expand Down
6 changes: 4 additions & 2 deletions src/Query/Providers/SqlServerQueryProviderFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ public IQueryProvider Create(DbConnection connection, IModelBuilder modelBuilder
new SqlServerSqlExpressionVisitorFactory(databaseMapper),
modelBuilder,
new Logger(),
databaseMapper);
databaseMapper,
null);

/// <summary>
/// Creates the SqlServer query provider.
Expand All @@ -47,7 +48,8 @@ public IQueryProvider Create(
new SqlServerSqlExpressionVisitorFactory(databaseMapper),
modelBuilder,
new Logger(),
databaseMapper);
databaseMapper,
null);

/// <summary>
/// Registers instance of this type to <see cref="QueryProviderFactories"/>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ private IQuery<Foo> CreateFooQuery()
new SqlServerSqlExpressionVisitorFactory(new DatabaseMapper(new ConventionModelMapper())),
Substitute.For<IModelBuilder>(),
new Logger(),
Substitute.For<IDatabaseMapper>()));
Substitute.For<IDatabaseMapper>(),
null));

return query;
}
Expand Down Expand Up @@ -285,7 +286,8 @@ private IQuery<FooIdentity> CreateFooIdentityQuery()
new SqlServerSqlExpressionVisitorFactory(new DatabaseMapper(new ConventionModelMapper())),
Substitute.For<IModelBuilder>(),
new Logger(),
Substitute.For<IDatabaseMapper>()));
Substitute.For<IDatabaseMapper>(),
null));

return query;
}
Expand Down
11 changes: 5 additions & 6 deletions tests/Kros.KORM.UnitTests/Query/Providers/QueryProviderShould.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,7 @@ private TestQueryProvider(DbConnection externalConnection)

public override DbProviderFactory DbProviderFactory => _dbProviderFactory;

public void CreateConnection()
{
var connection = Connection;
}
public void CreateConnection() => GetConnection();

public override IBulkInsert CreateBulkInsert()
{
Expand Down Expand Up @@ -511,15 +508,17 @@ private static SqlServerQueryProvider CreateQueryProvider(SqlConnection connecti
Substitute.For<ISqlExpressionVisitorFactory>(),
new ModelBuilder(Database.DefaultModelFactory),
Substitute.For<ILogger>(),
Substitute.For<IDatabaseMapper>());
Substitute.For<IDatabaseMapper>(),
null);

private static SqlServerQueryProvider CreateQueryProvider(string connectionString)
=> new SqlServerQueryProvider(
new KormConnectionSettings() { ConnectionString = connectionString },
Substitute.For<ISqlExpressionVisitorFactory>(),
new ModelBuilder(Database.DefaultModelFactory),
Substitute.For<ILogger>(),
Substitute.For<IDatabaseMapper>());
Substitute.For<IDatabaseMapper>(),
null);

#endregion
}
Expand Down
3 changes: 2 additions & 1 deletion tests/Kros.KORM.UnitTests/Query/QueryShould.cs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ private IQuery<Person> CreateQuery()
new SqlServerSqlExpressionVisitorFactory(mapper),
Substitute.For<IModelBuilder>(),
new Logger(),
Substitute.For<IDatabaseMapper>()));
Substitute.For<IDatabaseMapper>(),
null));

return query;
}
Expand Down

0 comments on commit baf8dc6

Please sign in to comment.