diff --git a/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs b/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs index 20e0717f6..df92e081a 100644 --- a/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs +++ b/src/MySqlConnector/Core/ConcatenatedCommandPayloadCreator.cs @@ -17,7 +17,7 @@ public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDict // ConcatenatedCommandPayloadCreator is only used by MySqlBatch, and MySqlBatchCommand doesn't expose attributes, // so just write an empty attribute set if the server needs it. - if (commandListPosition.CommandAt(commandListPosition.CommandIndex).Connection!.Session.SupportsQueryAttributes) + if (commandListPosition.CommandAt(commandListPosition.CommandIndex).Connection!.Session.Context.SupportsQueryAttributes) { // attribute count writer.WriteLengthEncodedInteger(0); diff --git a/src/MySqlConnector/Core/ConnectionPool.cs b/src/MySqlConnector/Core/ConnectionPool.cs index 17227b631..0fb5f2e4e 100644 --- a/src/MySqlConnector/Core/ConnectionPool.cs +++ b/src/MySqlConnector/Core/ConnectionPool.cs @@ -65,7 +65,7 @@ public async ValueTask GetSessionAsync(MySqlConnection connection } else { - if (ConnectionSettings.ConnectionReset || session.DatabaseOverride is not null) + if (ConnectionSettings.ConnectionReset || !session.Context.IsInitialDatabase()) { if (timeoutMilliseconds != 0) session.SetTimeout(Math.Max(1, timeoutMilliseconds - Utility.GetElapsedMilliseconds(startingTimestamp))); diff --git a/src/MySqlConnector/Core/Context.cs b/src/MySqlConnector/Core/Context.cs new file mode 100644 index 000000000..74448d4f5 --- /dev/null +++ b/src/MySqlConnector/Core/Context.cs @@ -0,0 +1,29 @@ +using MySqlConnector.Protocol; + +namespace MySqlConnector.Core; + +internal sealed class Context +{ + public Context(ProtocolCapabilities protocolCapabilities, string? database, int connectionId) + { + SupportsDeprecateEof = (protocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0; + SupportsCachedPreparedMetadata = (protocolCapabilities & ProtocolCapabilities.MariaDbCacheMetadata) != 0; + SupportsQueryAttributes = (protocolCapabilities & ProtocolCapabilities.QueryAttributes) != 0; + SupportsSessionTrack = (protocolCapabilities & ProtocolCapabilities.SessionTrack) != 0; + ConnectionId = connectionId; + Database = database; + m_initialDatabase = database; + } + + public bool SupportsDeprecateEof { get; } + public bool SupportsQueryAttributes { get; } + public bool SupportsSessionTrack { get; } + public bool SupportsCachedPreparedMetadata { get; } + public string? ClientCharset { get; set; } + + public string? Database { get; set; } + private readonly string? m_initialDatabase; + public bool IsInitialDatabase() => string.Equals(m_initialDatabase, Database, StringComparison.Ordinal); + + public int ConnectionId { get; set; } +} diff --git a/src/MySqlConnector/Core/ResultSet.cs b/src/MySqlConnector/Core/ResultSet.cs index a0887aa2c..8d40ee6fb 100644 --- a/src/MySqlConnector/Core/ResultSet.cs +++ b/src/MySqlConnector/Core/ResultSet.cs @@ -38,7 +38,7 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) var firstByte = payload.HeaderByte; if (firstByte == OkPayload.Signature) { - var ok = OkPayload.Create(payload.Span, Session.SupportsDeprecateEof, Session.SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, Session.Context); // if we've read a result set header then this is a SELECT statement, so we shouldn't overwrite RecordsAffected // (which should be -1 for SELECT) unless the server reports a non-zero count @@ -48,8 +48,6 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) if (ok.LastInsertId != 0) Command?.SetLastInsertedId((long) ok.LastInsertId); WarningCount = ok.WarningCount; - if (ok.NewSchema is not null) - Connection.Session.DatabaseOverride = ok.NewSchema; m_columnDefinitions = default; State = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData @@ -109,7 +107,7 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) } else { - var columnCountPacket = ColumnCountPayload.Create(payload.Span, Session.SupportsCachedPreparedMetadata); + var columnCountPacket = ColumnCountPayload.Create(payload.Span, Session.Context.SupportsCachedPreparedMetadata); var columnCount = columnCountPacket.ColumnCount; if (!columnCountPacket.MetadataFollows) { @@ -132,7 +130,7 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) m_columnDefinitions = m_columnDefinitionPayloadCache.AsMemory(0, columnCount); // if the server supports metadata caching but has re-sent it, something has changed since last prepare/execution and we need to update the columns - var preparedColumns = Session.SupportsCachedPreparedMetadata ? DataReader.LastUsedPreparedStatement?.Columns : null; + var preparedColumns = Session.Context.SupportsCachedPreparedMetadata ? DataReader.LastUsedPreparedStatement?.Columns : null; for (var column = 0; column < columnCount; column++) { @@ -156,7 +154,7 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) } } - if (!Session.SupportsDeprecateEof) + if (!Session.Context.SupportsDeprecateEof) { payload = await Session.ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); _ = EofPayload.Create(payload.Span); @@ -252,13 +250,13 @@ public async Task ReadAsync(IOBehavior ioBehavior, CancellationToken cance if (payload.HeaderByte == EofPayload.Signature) { - if (Session.SupportsDeprecateEof && OkPayload.IsOk(payload.Span, Session.SupportsDeprecateEof)) + if (Session.Context.SupportsDeprecateEof && OkPayload.IsOk(payload.Span, Session.Context)) { - var ok = OkPayload.Create(payload.Span, Session.SupportsDeprecateEof, Session.SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, Session.Context); BufferState = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; return null; } - if (!Session.SupportsDeprecateEof && EofPayload.IsEof(payload)) + if (!Session.Context.SupportsDeprecateEof && EofPayload.IsEof(payload)) { var eof = EofPayload.Create(payload.Span); BufferState = (eof.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index dec7faea3..70dcc3d28 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -44,6 +44,7 @@ public ServerSession(ILogger logger, ConnectionPool? pool, int poolGeneration, i m_activityTags = []; DataReader = new(); Log.CreatedNewSession(m_logger, Id); + Context = new Context(0, null, 0); } public string Id { get; } @@ -51,22 +52,19 @@ public ServerSession(ILogger logger, ConnectionPool? pool, int poolGeneration, i public bool SupportsPerQueryVariables => ServerVersion.IsMariaDb && ServerVersion.Version >= ServerVersions.MariaDbSupportsPerQueryVariables; public int ActiveCommandId { get; private set; } public int CancellationTimeout { get; private set; } - public int ConnectionId { get; set; } public byte[]? AuthPluginData { get; set; } public long CreatedTimestamp { get; } public ConnectionPool? Pool { get; } public int PoolGeneration { get; } public long LastLeasedTimestamp { get; set; } public long LastReturnedTimestamp { get; private set; } - public string? DatabaseOverride { get; set; } + public string HostName { get; private set; } public IPEndPoint? IPEndPoint => m_tcpClient?.Client.RemoteEndPoint as IPEndPoint; public string? UserID { get; private set; } public WeakReference? OwningConnection { get; set; } - public bool SupportsDeprecateEof { get; private set; } - public bool SupportsCachedPreparedMetadata { get; private set; } - public bool SupportsQueryAttributes { get; private set; } - public bool SupportsSessionTrack { get; private set; } + public Context Context { get; private set; } + public bool ProcAccessDenied { get; set; } public ICollection> ActivityTags => m_activityTags; public MySqlDataReader DataReader { get; set; } @@ -241,7 +239,7 @@ public async Task PrepareAsync(IMySqlCommand command, IOBehavior ioBehavior, Can ColumnDefinitionPayload.Initialize(ref parameters[i], new(columnsAndParameters, columnsAndParametersSize, payloadLength)); columnsAndParametersSize += payloadLength; } - if (!SupportsDeprecateEof) + if (!Context.SupportsDeprecateEof) { payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); EofPayload.Create(payload.Span); @@ -261,7 +259,7 @@ public async Task PrepareAsync(IMySqlCommand command, IOBehavior ioBehavior, Can ColumnDefinitionPayload.Initialize(ref columns[i], new(columnsAndParameters, columnsAndParametersSize, payloadLength)); columnsAndParametersSize += payloadLength; } - if (!SupportsDeprecateEof) + if (!Context.SupportsDeprecateEof) { payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); EofPayload.Create(payload.Span); @@ -315,12 +313,12 @@ public void FinishQuerying() // In order to handle this case, we issue a dummy query that will consume the pending cancellation. // See https://bugs.mysql.com/bug.php?id=45679 Log.SendingSleepToClearPendingCancellation(m_logger, Id); - var payload = SupportsQueryAttributes ? s_sleepWithAttributesPayload : s_sleepNoAttributesPayload; + var payload = Context.SupportsQueryAttributes ? s_sleepWithAttributesPayload : s_sleepNoAttributesPayload; #pragma warning disable CA2012 // Safe because method completes synchronously SendAsync(payload, IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); payload = ReceiveReplyAsync(IOBehavior.Synchronous, CancellationToken.None).GetAwaiter().GetResult(); #pragma warning restore CA2012 - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); } lock (m_lock) @@ -340,8 +338,8 @@ public void FinishQuerying() var activity = ActivitySourceHelper.StartActivity(name, m_activityTags); if (activity is { IsAllDataRequested: true }) { - if (DatabaseOverride is not null) - activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, DatabaseOverride); + if (!Context.IsInitialDatabase()) + activity.SetTag(ActivitySourceHelper.DatabaseNameTagName, Context.Database); if (tagName1 is not null) activity.SetTag(tagName1, tagValue1); } @@ -454,7 +452,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } ServerVersion = new(initialHandshake.ServerVersion); - ConnectionId = initialHandshake.ConnectionId; + Context = new Context(initialHandshake.ProtocolCapabilities, cs.Database, initialHandshake.ConnectionId); AuthPluginData = initialHandshake.AuthPluginData; m_useCompression = cs.UseCompression && (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Compress) != 0; CancellationTimeout = cs.CancellationTimeout; @@ -462,22 +460,18 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella // set activity tags { - var connectionId = ConnectionId.ToString(CultureInfo.InvariantCulture); + var connectionId = Context.ConnectionId.ToString(CultureInfo.InvariantCulture); m_activityTags[ActivitySourceHelper.DatabaseConnectionIdTagName] = connectionId; if (activity is { IsAllDataRequested: true }) activity.SetTag(ActivitySourceHelper.DatabaseConnectionIdTagName, connectionId); } m_supportsConnectionAttributes = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.ConnectionAttributes) != 0; - SupportsDeprecateEof = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0; - SupportsCachedPreparedMetadata = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.MariaDbCacheMetadata) != 0; - SupportsQueryAttributes = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.QueryAttributes) != 0; - SupportsSessionTrack = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.SessionTrack) != 0; var serverSupportsSsl = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.Ssl) != 0; m_characterSet = ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? CharacterSet.Utf8Mb4GeneralCaseInsensitive : CharacterSet.Utf8Mb3GeneralCaseInsensitive; m_setNamesPayload = ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 ? - (SupportsQueryAttributes ? s_setNamesUtf8mb4WithAttributesPayload : s_setNamesUtf8mb4NoAttributesPayload) : - (SupportsQueryAttributes ? s_setNamesUtf8WithAttributesPayload : s_setNamesUtf8NoAttributesPayload); + (Context.SupportsQueryAttributes ? s_setNamesUtf8mb4WithAttributesPayload : s_setNamesUtf8mb4NoAttributesPayload) : + (Context.SupportsQueryAttributes ? s_setNamesUtf8WithAttributesPayload : s_setNamesUtf8NoAttributesPayload); // disable pipelining for RDS MySQL 5.7 (assuming Aurora); otherwise take it from the connection string or default to true if (!cs.Pipelining.HasValue && ServerVersion.Version.Major == 5 && ServerVersion.Version.Minor == 7 && HostName.EndsWith(".rds.amazonaws.com", StringComparison.OrdinalIgnoreCase)) @@ -505,7 +499,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } } - Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, ConnectionId, m_useCompression, m_supportsConnectionAttributes, SupportsDeprecateEof, SupportsCachedPreparedMetadata, serverSupportsSsl, SupportsSessionTrack, m_supportsPipelining, SupportsQueryAttributes); + Log.SessionMadeConnection(m_logger, Id, ServerVersion.OriginalString, Context.ConnectionId, m_useCompression, m_supportsConnectionAttributes, Context.SupportsDeprecateEof, Context.SupportsCachedPreparedMetadata, serverSupportsSsl, Context.SupportsSessionTrack, m_supportsPipelining, Context.SupportsQueryAttributes); if (cs.SslMode != MySqlSslMode.None && (cs.SslMode != MySqlSslMode.Preferred || serverSupportsSsl)) { @@ -532,16 +526,21 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella payload = await SwitchAuthenticationAsync(cs, password, payload, ioBehavior, cancellationToken).ConfigureAwait(false); } - var ok = OkPayload.Create(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + var ok = OkPayload.Create(payload.Span, Context); var statusInfo = ok.StatusInfo; if (m_useCompression) m_payloadHandler = new CompressedPayloadHandler(m_payloadHandler.ByteHandler); // set 'collation_connection' to the server default - await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); - payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + if (Context.ClientCharset == null || ServerVersion.Version >= ServerVersions.SupportsUtf8Mb4 + ? !string.Equals(Context.ClientCharset, "utf8mb4", StringComparison.Ordinal) + : !string.Equals(Context.ClientCharset, "utf8", StringComparison.Ordinal)) + { + await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); + payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); + OkPayload.Verify(payload.Span, Context); + } if (ShouldGetRealServerDetails(cs)) await GetRealServerDetailsAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); @@ -571,9 +570,9 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn ClearPreparedStatements(); PayloadData payload; - if (DatabaseOverride is null && - ((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) || - (ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0))) + if (Context.IsInitialDatabase() && + ((!ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.SupportsResetConnection) >= 0) || + (ServerVersion.IsMariaDb && ServerVersion.Version.CompareTo(ServerVersions.MariaDbSupportsResetConnection) >= 0))) { if (m_supportsPipelining) { @@ -584,10 +583,10 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn // read two OK replies payload = await ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); payload = await ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); return true; } @@ -595,19 +594,19 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn Log.SendingResetConnectionRequest(m_logger, Id, ServerVersion.OriginalString); await SendAsync(ResetConnectionPayload.Instance, ioBehavior, cancellationToken).ConfigureAwait(false); payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); } else { // optimistically hash the password with the challenge from the initial handshake (supported by MariaDB; doesn't appear to be supported by MySQL) - if (DatabaseOverride is null) + if (Context.IsInitialDatabase()) { Log.SendingChangeUserRequest(m_logger, Id, ServerVersion.OriginalString); } else { - Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, DatabaseOverride); - DatabaseOverride = null; + Log.SendingChangeUserRequestDueToChangedDatabase(m_logger, Id, Context.Database!); + Context.Database = cs.Database; } var password = GetPassword(cs, connection); var hashedPassword = AuthenticationUtility.CreateAuthenticationResponse(AuthPluginData!, password); @@ -619,13 +618,13 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn Log.OptimisticReauthenticationFailed(m_logger, Id); payload = await SwitchAuthenticationAsync(cs, password, payload, ioBehavior, cancellationToken).ConfigureAwait(false); } - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); } // set 'collation_connection' to the server default await SendAsync(m_setNamesPayload, ioBehavior, cancellationToken).ConfigureAwait(false); payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); return true; } @@ -684,7 +683,7 @@ private async Task SwitchAuthenticationAsync(ConnectionSettings cs, payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); // OK payload can be sent immediately (e.g., if password is empty) bypassing even the fast authentication path - if (OkPayload.IsOk(payload.Span, SupportsDeprecateEof)) + if (OkPayload.IsOk(payload.Span, Context)) return payload; var cachingSha2ServerResponsePayload = CachingSha2ServerResponsePayload.Create(payload.Span); @@ -824,7 +823,7 @@ public async ValueTask TryPingAsync(bool logInfo, IOBehavior ioBehavior, C Log.PingingServer(m_logger, Id); await SendAsync(PingPayload.Instance, ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + OkPayload.Verify(payload.Span, Context); Log.SuccessfullyPingedServer(m_logger, logInfo ? LogLevel.Information : LogLevel.Trace, Id); return true; } @@ -1632,7 +1631,7 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation Log.DetectedProxy(m_logger, Id); try { - var payload = SupportsQueryAttributes ? s_selectConnectionIdVersionWithAttributesPayload : s_selectConnectionIdVersionNoAttributesPayload; + var payload = Context.SupportsQueryAttributes ? s_selectConnectionIdVersionWithAttributesPayload : s_selectConnectionIdVersionNoAttributesPayload; await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); // column count: 2 @@ -1642,7 +1641,7 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); - if (!SupportsDeprecateEof) + if (!Context.SupportsDeprecateEof) { payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); _ = EofPayload.Create(payload.Span); @@ -1662,15 +1661,15 @@ static void ReadRow(ReadOnlySpan span, out int? connectionId, out ServerVe // OK/EOF payload payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); - if (OkPayload.IsOk(payload.Span, SupportsDeprecateEof)) - OkPayload.Verify(payload.Span, SupportsDeprecateEof, SupportsSessionTrack); + if (OkPayload.IsOk(payload.Span, Context)) + OkPayload.Verify(payload.Span, Context); else EofPayload.Create(payload.Span); if (connectionId is int newConnectionId && serverVersion is not null) { - Log.ChangingConnectionId(m_logger, Id, ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString); - ConnectionId = newConnectionId; + Log.ChangingConnectionId(m_logger, Id, Context.ConnectionId, newConnectionId, ServerVersion.OriginalString, serverVersion.OriginalString); + Context.ConnectionId = newConnectionId; ServerVersion = serverVersion; } } diff --git a/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs b/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs index e6c2c641d..9da955d32 100644 --- a/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs +++ b/src/MySqlConnector/Core/SingleCommandPayloadCreator.cs @@ -24,7 +24,7 @@ public bool WriteQueryCommand(ref CommandListPosition commandListPosition, IDict Log.PreparingCommandPayload(command.Logger, command.Connection!.Session.Id, command.CommandText!); writer.Write((byte) CommandKind.Query); - var supportsQueryAttributes = command.Connection!.Session.SupportsQueryAttributes; + var supportsQueryAttributes = command.Connection!.Session.Context.SupportsQueryAttributes; if (supportsQueryAttributes) { // attribute count @@ -83,7 +83,7 @@ private static void WritePreparedStatement(IMySqlCommand command, PreparedStatem Log.PreparingCommandPayloadWithId(command.Logger, command.Connection!.Session.Id, preparedStatement.StatementId, command.CommandText!); var attributes = command.RawAttributes; - var supportsQueryAttributes = command.Connection!.Session.SupportsQueryAttributes; + var supportsQueryAttributes = command.Connection!.Session.Context.SupportsQueryAttributes; writer.Write(preparedStatement.StatementId); // NOTE: documentation is not updated yet, but due to bugs in MySQL Server 8.0.23-8.0.25, the PARAMETER_COUNT_AVAILABLE (0x08) diff --git a/src/MySqlConnector/Logging/Log.cs b/src/MySqlConnector/Logging/Log.cs index 7cb3bd432..24421abd1 100644 --- a/src/MySqlConnector/Logging/Log.cs +++ b/src/MySqlConnector/Logging/Log.cs @@ -47,7 +47,7 @@ internal static partial class Log public static partial void AutoDetectedAurora57(ILogger logger, string sessionId, string hostName); [LoggerMessage(EventIds.SessionMadeConnection, LogLevel.Debug, "Session {SessionId} made connection; server version {ServerVersion}; connection ID {ConnectionId}; supports: compression {SupportsCompression}, attributes {SupportsAttributes}, deprecate EOF {SupportsDeprecateEof}, cached metadata {SupportsCachedMetadata}, SSL {SupportsSsl}, session track {SupportsSessionTrack}, pipelining {SupportsPipelining}, query attributes {SupportsQueryAttributes}")] - public static partial void SessionMadeConnection(ILogger logger, string sessionId, string serverVersion, int connectionId, bool supportsCompression, bool supportsAttributes, bool supportsDeprecateEof, bool supportsCachedMetadata, bool supportsSsl, bool supportsSessionTrack, bool supportsPipelining, bool supportsQueryAttributes); + public static partial void SessionMadeConnection(ILogger logger, string sessionId, string serverVersion, long connectionId, bool supportsCompression, bool supportsAttributes, bool supportsDeprecateEof, bool supportsCachedMetadata, bool supportsSsl, bool supportsSessionTrack, bool supportsPipelining, bool supportsQueryAttributes); [LoggerMessage(EventIds.ServerDoesNotSupportSsl, LogLevel.Error, "Session {SessionId} requires SSL but server doesn't support it")] public static partial void ServerDoesNotSupportSsl(ILogger logger, string sessionId); @@ -184,7 +184,7 @@ internal static partial class Log public static partial void DetectedProxy(ILogger logger, string sessionId); [LoggerMessage(EventIds.ChangingConnectionId, LogLevel.Debug, "Session {SessionId} changing connection id from {OldConnectionId} to {ConnectionId} and server version from {OldServerVersion} to {ServerVersion}")] - public static partial void ChangingConnectionId(ILogger logger, string sessionId, int oldConnectionId, int connectionId, string oldServerVersion, string serverVersion); + public static partial void ChangingConnectionId(ILogger logger, string sessionId, long oldConnectionId, long connectionId, string oldServerVersion, string serverVersion); [LoggerMessage(EventIds.FailedToGetConnectionId, LogLevel.Information, "Session {SessionId} failed to get CONNECTION_ID(), VERSION()")] public static partial void FailedToGetConnectionId(ILogger logger, Exception exception, string sessionId); diff --git a/src/MySqlConnector/MySqlConnection.cs b/src/MySqlConnector/MySqlConnection.cs index e9b21042e..13b297d8c 100644 --- a/src/MySqlConnector/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlConnection.cs @@ -152,7 +152,7 @@ private async ValueTask BeginTransactionAsync(IsolationLevel i Log.StartingTransaction(m_transactionLogger, m_session!.Id); // get the bytes for both payloads concatenated together (suitable for pipelining) - var startTransactionPayload = GetStartTransactionPayload(isolationLevel, isReadOnly, m_session.SupportsQueryAttributes); + var startTransactionPayload = GetStartTransactionPayload(isolationLevel, isReadOnly, m_session.Context.SupportsQueryAttributes); if (GetInitializedConnectionSettings() is { UseCompression: false, Pipelining: not false }) { @@ -161,10 +161,10 @@ private async ValueTask BeginTransactionAsync(IsolationLevel i // read the two OK replies var payload = await m_session.ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session.Context); payload = await m_session.ReceiveReplyAsync(1, ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session.Context); } else { @@ -172,12 +172,12 @@ private async ValueTask BeginTransactionAsync(IsolationLevel i await m_session.SendAsync(new Protocol.PayloadData(startTransactionPayload.Slice(4, startTransactionPayload.Span[0])), ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session.Context); await m_session.SendAsync(new Protocol.PayloadData(startTransactionPayload.Slice(8 + startTransactionPayload.Span[0], startTransactionPayload.Span[startTransactionPayload.Span[0] + 4])), ioBehavior, cancellationToken).ConfigureAwait(false); payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, m_session.Context); } var transaction = new MySqlTransaction(this, isolationLevel, m_transactionLogger); @@ -487,8 +487,10 @@ private async Task ChangeDatabaseAsync(IOBehavior ioBehavior, string databaseNam using (var initDatabasePayload = InitDatabasePayload.Create(databaseName)) await m_session!.SendAsync(initDatabasePayload, ioBehavior, cancellationToken).ConfigureAwait(false); var payload = await m_session.ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, m_session.SupportsDeprecateEof, m_session.SupportsSessionTrack); - m_session.DatabaseOverride = databaseName; + OkPayload.Verify(payload.Span, m_session.Context); + + // for non session tracking servers + m_session.Context.Database = databaseName; } public new MySqlCommand CreateCommand() => (MySqlCommand) base.CreateCommand(); @@ -603,7 +605,7 @@ public async ValueTask ResetConnectionAsync(CancellationToken cancellationToken Log.ResettingConnection(m_logger, session.Id); await session.SendAsync(ResetConnectionPayload.Instance, AsyncIOBehavior, cancellationToken).ConfigureAwait(false); var payload = await session.ReceiveReplyAsync(AsyncIOBehavior, cancellationToken).ConfigureAwait(false); - OkPayload.Verify(payload.Span, session.SupportsDeprecateEof, session.SupportsSessionTrack); + OkPayload.Verify(payload.Span, session.Context); } [AllowNull] @@ -626,7 +628,7 @@ public override string ConnectionString } } - public override string Database => m_session?.DatabaseOverride ?? GetConnectionSettings().Database; + public override string Database => m_session?.Context.Database ?? GetConnectionSettings().Database; public override ConnectionState State => m_connectionState; @@ -637,7 +639,7 @@ public override string ConnectionString /// /// The connection ID from MySQL Server. /// - public int ServerThread => Session.ConnectionId; + public int ServerThread => Session.Context.ConnectionId; /// /// Gets or sets the delegate used to provide client certificates for connecting to a server. diff --git a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs index d5d4a1aa2..6b9e15093 100644 --- a/src/MySqlConnector/Protocol/Payloads/OkPayload.cs +++ b/src/MySqlConnector/Protocol/Payloads/OkPayload.cs @@ -1,4 +1,6 @@ +using System.Globalization; using System.Text; +using MySqlConnector.Core; using MySqlConnector.Protocol.Serialization; using MySqlConnector.Utilities; @@ -11,7 +13,6 @@ internal sealed class OkPayload public ServerStatus ServerStatus { get; } public int WarningCount { get; } public string? StatusInfo { get; } - public string? NewSchema { get; } public const byte Signature = 0x00; @@ -20,56 +21,51 @@ internal sealed class OkPayload * https://mariadb.com/kb/en/the-mariadb-library/resultset/ * https://github.com/MariaDB/mariadb-connector-j/blob/5fa814ac6e1b4c9cb6d141bd221cbd5fc45c8a78/src/main/java/org/mariadb/jdbc/internal/com/read/resultset/SelectResultSet.java#L443-L444 */ - public static bool IsOk(ReadOnlySpan span, bool deprecateEof) => + public static bool IsOk(ReadOnlySpan span, Context context) => span.Length > 0 && (span.Length > 6 && span[0] == Signature || - deprecateEof && span.Length < 0xFF_FFFF && span[0] == EofPayload.Signature); + context.SupportsDeprecateEof && span.Length < 0xFF_FFFF && span[0] == EofPayload.Signature); /// /// Creates an from the given , or throws /// if the bytes do not represent a valid . /// /// The bytes from which to read an OK packet. - /// Whether the flag was set on the connection. - /// Whether flag was set on the connection. + /// Current connection variables context /// A with the contents of the OK packet. /// Thrown when the bytes are not a valid OK packet. - public static OkPayload Create(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack) => - Read(span, deprecateEof, clientSessionTrack, true)!; + public static OkPayload Create(ReadOnlySpan span, Context context) => + Read(span, context, true)!; /// /// Verifies that the bytes in the given form a valid , or throws /// if they do not. /// /// The bytes from which to read an OK packet. - /// Whether the flag was set on the connection. - /// Whether flag was set on the connection. + /// Current connection variables context /// Thrown when the bytes are not a valid OK packet. - public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack) => - Read(span, deprecateEof, clientSessionTrack, createPayload: false); + public static void Verify(ReadOnlySpan span, Context context) => + Read(span, context, createPayload: false); - private static OkPayload? Read(ReadOnlySpan span, bool deprecateEof, bool clientSessionTrack, bool createPayload) + private static OkPayload? Read(ReadOnlySpan span, Context context, bool createPayload) { var reader = new ByteArrayReader(span); var signature = reader.ReadByte(); - if (signature != Signature && (!deprecateEof || signature != EofPayload.Signature)) + if (signature != Signature && (!context.SupportsDeprecateEof || signature != EofPayload.Signature)) throw new FormatException($"Expected to read 0x00 or 0xFE but got 0x{signature:X2}"); var affectedRowCount = reader.ReadLengthEncodedInteger(); var lastInsertId = reader.ReadLengthEncodedInteger(); var serverStatus = (ServerStatus) reader.ReadUInt16(); var warningCount = (int) reader.ReadUInt16(); - string? newSchema = null; ReadOnlySpan statusBytes; - if (clientSessionTrack) + if (context.SupportsSessionTrack) { if (reader.BytesRemaining > 0) { statusBytes = reader.ReadLengthEncodedByteString(); // human-readable info - - if ((serverStatus & ServerStatus.SessionStateChanged) == ServerStatus.SessionStateChanged && reader.BytesRemaining > 0) + while (reader.BytesRemaining > 0) { - // implies ProtocolCapabilities.SessionTrack var sessionStateChangeDataLength = checked((int) reader.ReadLengthEncodedInteger()); var endOffset = reader.Offset + sessionStateChangeDataLength; while (reader.Offset < endOffset) @@ -79,7 +75,28 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien switch (kind) { case SessionTrackKind.Schema: - newSchema = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString()); + context.Database = Encoding.UTF8.GetString(reader.ReadLengthEncodedByteString()); + break; + + case SessionTrackKind.SystemVariables: + var systemVariableOffset = reader.Offset + dataLength; + do + { + var variableSv = Encoding.ASCII.GetString(reader.ReadLengthEncodedByteString()); + var lenSv = reader.ReadLengthEncodedIntegerOrNull(); + var valueSv = lenSv == -1 + ? null + : Encoding.ASCII.GetString(reader.ReadByteString(lenSv)); + switch (variableSv) + { + case "character_set_client": + context.ClientCharset = valueSv; + break; + case "connection_id": + context.ConnectionId = Convert.ToInt32(valueSv, CultureInfo.InvariantCulture); + break; + } + } while (reader.Offset < systemVariableOffset); break; default: @@ -109,7 +126,7 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien { var statusInfo = statusBytes.Length == 0 ? null : Encoding.UTF8.GetString(statusBytes); - if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null && newSchema is null) + if (affectedRowCount == 0 && lastInsertId == 0 && warningCount == 0 && statusInfo is null) { if (serverStatus == ServerStatus.AutoCommit) return s_autoCommitOk; @@ -117,7 +134,7 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien return s_autoCommitSessionStateChangedOk; } - return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo, newSchema); + return new OkPayload(affectedRowCount, lastInsertId, serverStatus, warningCount, statusInfo); } else { @@ -125,16 +142,15 @@ public static void Verify(ReadOnlySpan span, bool deprecateEof, bool clien } } - private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo, string? newSchema) + private OkPayload(ulong affectedRowCount, ulong lastInsertId, ServerStatus serverStatus, int warningCount, string? statusInfo) { AffectedRowCount = affectedRowCount; LastInsertId = lastInsertId; ServerStatus = serverStatus; WarningCount = warningCount; StatusInfo = statusInfo; - NewSchema = newSchema; } - private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, null, null); - private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, null, null); + private static readonly OkPayload s_autoCommitOk = new(0, 0, ServerStatus.AutoCommit, 0, null); + private static readonly OkPayload s_autoCommitSessionStateChangedOk = new(0, 0, ServerStatus.AutoCommit | ServerStatus.SessionStateChanged, 0, null); }