Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Make QueryProvider.Connection virtual #55

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions src/Data/ConnectionHelper.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Text;

namespace Kros.KORM.Data
{
Expand Down Expand Up @@ -45,9 +43,6 @@ protected virtual void Dispose(bool disposing)
}
}

public void Dispose()
{
Dispose(true);
}
public void Dispose() => Dispose(true);
}
}
63 changes: 36 additions & 27 deletions src/Data/TransactionHelper.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Kros.Utils;
using Kros.KORM.Properties;
using Kros.Utils;
using System;
using System.Collections.Generic;
using System.Data;
Expand All @@ -13,34 +14,35 @@ namespace Kros.KORM.Data
internal class TransactionHelper
{
public const IsolationLevel DefaultIsolationLevel = IsolationLevel.ReadCommitted;
private const int TIMEOUT_DEFAULT = 30;

private readonly DbConnection _connection;
private Transaction _topTransaction;
private bool _canCommit = true;
private readonly Stack<ITransaction> _transactions = new Stack<ITransaction>();
private const int DefaultCommandTimeout = 30;

#region Nested types

private class Transaction : ITransaction
{
private readonly ConnectionHelper _connectionHelper;
private readonly DbConnection _connection;
private readonly bool _closeConnection;
private readonly Lazy<DbTransaction> _transaction;
private readonly TransactionHelper _transactionHelper;
private bool _wasCommitOrRollback = false;

public Transaction(TransactionHelper transactionHelper, ConnectionHelper connectionHelper, IsolationLevel isolationLevel)
public Transaction(
TransactionHelper transactionHelper,
DbConnection connection,
bool closeConnection,
IsolationLevel isolationLevel)
{
_connectionHelper = connectionHelper;
_transaction = new Lazy<DbTransaction>(() => connectionHelper.Connection.BeginTransaction(isolationLevel));
_transactionHelper = transactionHelper;
_connection = connection;
_closeConnection = closeConnection;
_transaction = new Lazy<DbTransaction>(() => _connection.BeginTransaction(isolationLevel));
}

public void Commit()
{
_wasCommitOrRollback = true;
if (_transactionHelper.CanCommitTransaction)
{
_wasCommitOrRollback = true;
_transaction.Value.Commit();
_transactionHelper.EndTransaction(true);
}
Expand All @@ -53,7 +55,7 @@ public void Rollback()
_transactionHelper.EndTransaction(false);
}

public int CommandTimeout { get; set; } = TIMEOUT_DEFAULT;
public int CommandTimeout { get; set; } = DefaultCommandTimeout;

public static implicit operator DbTransaction(Transaction transaction)
=> transaction?._transaction.Value;
Expand All @@ -69,7 +71,10 @@ public void Dispose()
{
_transaction.Value.Dispose();
}
_connectionHelper.Dispose();
if (_closeConnection)
{
_connection.Close();
}
}
}

Expand All @@ -87,8 +92,14 @@ public NestedTransaction(TransactionHelper transactionHelper, int timeout)

public void Commit()
{
_wasCommitOrRollback = true;
_transactionHelper.EndTransaction(true);
}

public void Rollback()
{
_wasCommitOrRollback = true;
_transactionHelper.EndTransaction(false);
}

public void Dispose()
Expand All @@ -99,39 +110,39 @@ public void Dispose()
}
}

public void Rollback()
{
_transactionHelper.EndTransaction(false);
_wasCommitOrRollback = true;
}

public int CommandTimeout
{
get => _timeout;
set { }
get => DefaultCommandTimeout;
set => throw new InvalidOperationException(Resources.NestedTransactionCommandTimeoutIsReadonly);
}
}

#endregion

public TransactionHelper(DbConnection connection)
private readonly DbConnection _connection;
private readonly bool _closeConnection;
private Transaction _topTransaction;
private bool _canCommit = true;
private readonly Stack<ITransaction> _transactions = new Stack<ITransaction>();

public TransactionHelper(DbConnection connection, bool closeConnection)
{
_connection = Check.NotNull(connection, nameof(connection));
_closeConnection = closeConnection;
}

public ITransaction BeginTransaction(IsolationLevel isolationLevel)
{
if (_transactions.Count == 0)
{
_topTransaction = new Transaction(this, new ConnectionHelper(_connection), isolationLevel);
_topTransaction = new Transaction(this, _connection, _closeConnection, isolationLevel);
_transactions.Push(_topTransaction);
_canCommit = true;
}
else
{
_transactions.Push(new NestedTransaction(this, _topTransaction.CommandTimeout));
}

return _transactions.Peek();
}

Expand All @@ -145,7 +156,6 @@ private void EndTransaction(bool success)
{
_canCommit &= success;
_transactions.Pop();

if (!_transactions.Any())
{
_topTransaction = null;
Expand All @@ -160,7 +170,6 @@ public DbCommand CreateCommand()
cmd.Transaction = _topTransaction;
cmd.CommandTimeout = _topTransaction.CommandTimeout;
}

return cmd;
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/Database.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,14 @@ public partial class Database : IDatabase
/// <summary>
/// Builder for creating <see cref="IDatabase"/> instance.
/// </summary>
[Obsolete("Use CreateBuilder method.")]
public static IDatabaseBuilder Builder => new DatabaseBuilder();

/// <summary>
/// Creates a builder for creating <see cref="IDatabase"/> instance.
/// </summary>
public static IDatabaseBuilder CreateBuilder() => new DatabaseBuilder();

#endregion

#region Private fields
Expand Down
14 changes: 14 additions & 0 deletions src/IAuthTokenProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
namespace Kros.KORM
{
/// <summary>
/// Support for token-based authentication for SQL Server.
/// </summary>
public interface IAuthTokenProvider
{
/// <summary>
/// Returns authentication token, or <see langword="null" /> value, if token can not be obtained.
/// </summary>
/// <returns>Authentication token.</returns>
string GetToken();
}
}
47 changes: 29 additions & 18 deletions src/Query/Providers/QueryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -122,7 +123,7 @@ protected QueryProvider(
InitSqlExpressionVisitor(Check.NotNull(sqlGeneratorFactory, nameof(sqlGeneratorFactory)));
IsExternalConnection = false;
_modelBuilder = Check.NotNull(modelBuilder, nameof(modelBuilder));
_transactionHelper = new Lazy<TransactionHelper>(() => new TransactionHelper(Connection));
_transactionHelper = new Lazy<TransactionHelper>(TransactionHelperFactory);
}

/// <summary>
Expand All @@ -147,11 +148,17 @@ protected QueryProvider(

InitSqlExpressionVisitor(Check.NotNull(sqlGeneratorFactory, nameof(sqlGeneratorFactory)));
IsExternalConnection = true;
_transactionHelper = new Lazy<TransactionHelper>(() => new TransactionHelper(Connection));
_transactionHelper = new Lazy<TransactionHelper>(TransactionHelperFactory);
}

private void InitSqlExpressionVisitor(ISqlExpressionVisitorFactory sqlGeneratorFactory)
=> _sqlExpressionVisitor = new Lazy<ISqlExpressionVisitor>(() => sqlGeneratorFactory.CreateVisitor(Connection));
=> _sqlExpressionVisitor = new Lazy<ISqlExpressionVisitor>(() => sqlGeneratorFactory.CreateVisitor(GetConnection()));

private TransactionHelper TransactionHelperFactory()
{
DbConnection connection = GetConnection();
return new TransactionHelper(connection, !connection.State.HasFlag(ConnectionState.Open));
}

#endregion

Expand Down Expand Up @@ -190,7 +197,7 @@ public void SetParameterDbType(DbParameter parameter, string tableName, string c
private TableSchema LoadTableSchema(string tableName)
{
IDatabaseSchemaLoader schemaLoader = GetSchemaLoader();
TableSchema tableSchema = schemaLoader.LoadTableSchema(Connection, tableName);
TableSchema tableSchema = schemaLoader.LoadTableSchema(GetConnection(), tableName);
return tableSchema
?? throw new InvalidOperationException(string.Format(Resources.QueryProviderCouldNotGetTableSchema, tableName));
}
Expand Down Expand Up @@ -541,7 +548,7 @@ public IIdGeneratorsForDatabaseInit GetIdGeneratorsForDatabaseInit()

private IDbConnection GetConnectionForIdGenerator()
{
var connection = (Connection as ICloneable).Clone() as DbConnection;
var connection = (GetConnection() as ICloneable).Clone() as DbConnection;
try
{
connection.Open();
Expand Down Expand Up @@ -631,24 +638,27 @@ public TResult Execute<TResult>(Expression expression)
/// Vráti spojenie na databázu s ktorou trieda pracuje. Ak trieda bola vytvorená iba so zadaným
/// connection string-om, je vytvorené nové spojenie.
/// </summary>
protected DbConnection Connection
[Obsolete("Use GetConnection() method.")]
protected DbConnection Connection => GetConnection();

/// <summary>
/// Returns (creates if needed) connection.
/// </summary>
/// <returns><see cref="DbConnection"/> instance.</returns>
protected virtual DbConnection GetConnection()
{
get
if (_connection == null)
{
if (_connection == null)
{
_connection = DbProviderFactory.CreateConnection();
_connection.ConnectionString = _connectionSettings.ConnectionString;
}
return _connection;
_connection = DbProviderFactory.CreateConnection();
_connection.ConnectionString = _connectionSettings.ConnectionString;
}
return _connection;
}

private Data.ConnectionHelper OpenConnection()
{
return new Data.ConnectionHelper(Connection);
}
private Data.ConnectionHelper OpenConnection() => new Data.ConnectionHelper(GetConnection());

[SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "Query is external or generated.")]
private DbCommandInfo CreateCommand<T>(IQuery<T> query)
{
DbCommand command = _transactionHelper.Value.CreateCommand();
Expand All @@ -671,6 +681,8 @@ private DbCommandInfo CreateCommand<T>(IQuery<T> query)
protected internal void SetQueryFilter<T>(IQuery<T> query, ISqlExpressionVisitor sqlVisitor)
=> (query as IQueryBaseInternal).ApplyQueryFilter(_databaseMapper, sqlVisitor);

[SuppressMessage("Security", "CA2100:Review SQL queries for security vulnerabilities",
Justification = "Query is external or generated.")]
private DbCommand CreateCommand(string commandText, CommandParameterCollection parameters)
{
DbCommand command = _transactionHelper.Value.CreateCommand();
Expand All @@ -683,7 +695,6 @@ private DbCommand CreateCommand(string commandText, CommandParameterCollection p
AddCommandParameter(command, parameter);
}
}

return command;
}

Expand Down
Loading