Skip to content

Commit

Permalink
fix describe metadata query portal and add test case for query isolat…
Browse files Browse the repository at this point in the history
…ion issue
  • Loading branch information
goldmedal committed Jan 29, 2024
1 parent c82f2f8 commit 45f47fe
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ private void handleExecute(ByteBuf buffer, Channel channel)
portal.getRowCount(),
resultFormatCodes);
portal.setRowCount(resultSetSender.sendResultSet());
// clean metadata query after executed
wireProtocolSession.removeMetadataQuery(portal.getPreparedStatement().getName(), portal.getName());
}
catch (Exception e) {
LOG.error(e, format("Execute query failed. Statement: %s. Root cause is %s", statement, e.getMessage()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
import static io.accio.main.wireprotocol.PostgresWireProtocol.isIgnoredCommand;
import static io.accio.main.wireprotocol.PostgresWireProtocolErrorCode.INVALID_PREPARED_STATEMENT_NAME;
import static io.accio.main.wireprotocol.PreparedStatement.RESERVED_DRY_RUN_NAME;
import static io.accio.main.wireprotocol.PreparedStatement.cloneWithName;
import static io.accio.main.wireprotocol.WireProtocolSession.PreparedStmtPortalName.preparedStmtPortalName;
import static io.accio.main.wireprotocol.patterns.PostgreSqlRewriteUtil.rewriteWithParameters;
import static io.trino.execution.ParameterExtractor.getParameterCount;
Expand Down Expand Up @@ -243,8 +242,7 @@ public List<Integer> describeStatement(String name)
*/
public Optional<List<Column>> dryRunAfterDescribeStatement(String statementName, List<Object> params, @Nullable FormatCodes.FormatCode[] resultFormatCodes)
{
preparedStatements.put(RESERVED_DRY_RUN_NAME, cloneWithName(preparedStatements.get(statementName), RESERVED_DRY_RUN_NAME));
metadataQueries.put(preparedStmtPortalName(RESERVED_DRY_RUN_NAME, null), Query.builder(preparedStatements.get(RESERVED_DRY_RUN_NAME)).build());
parse(RESERVED_DRY_RUN_NAME, preparedStatements.get(statementName).getOriginalStatement(), preparedStatements.get(statementName).getParamTypeOids());
bind(RESERVED_DRY_RUN_NAME, RESERVED_DRY_RUN_NAME, params, resultFormatCodes);

Optional<List<Column>> result = describePortal(RESERVED_DRY_RUN_NAME);
Expand Down Expand Up @@ -326,7 +324,7 @@ private void parseDataSourceQuery(String statementName, String statement, List<I
statementTrimmed,
isSessionCommand(rewrittenStatement),
QueryLevel.DATASOURCE));
// metadataQueries.remove(preparedStmtPortalName(statementName, null));
metadataQueries.remove(preparedStmtPortalName(statementName, null));
LOG.info("Create preparedStatement %s", statementName);
}

Expand Down Expand Up @@ -412,7 +410,7 @@ public void bind(String portalName, String statementName, List<Object> params, @
private void resetMetadataQuery(String statementName, String portalName)
{
PreparedStmtPortalName name = preparedStmtPortalName(statementName, portalName);
metadataQueries.get(name).getPortal().ifPresent(Portal::close);
Optional.ofNullable(metadataQueries.get(name)).flatMap(Query::getPortal).ifPresent(Portal::close);
metadataQueries.remove(name);
}

Expand Down Expand Up @@ -453,6 +451,11 @@ private Optional<ConnectorRecordIterator> executeCache(Portal portal)
});
}

public void removeMetadataQuery(String statementName, String portalName)
{
metadataQueries.remove(preparedStmtPortalName(statementName, portalName));
}

private CompletableFuture<Optional<Iterable<?>>> executeSessionCommand(Portal portal)
{
throw new UnsupportedOperationException();
Expand Down
71 changes: 71 additions & 0 deletions accio-tests/src/test/java/io/accio/testing/TestMetadataQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.assertj.core.api.AssertionsForClassTypes;
import org.testng.annotations.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.sql.Connection;
Expand Down Expand Up @@ -190,4 +191,74 @@ public void testExecuteAndDescribeLevel2Query()
protocolClient.assertCommandComplete("SELECT 1");
}
}

@Test
public void testQueryLevelIsolationIssueInOneTransaction()
throws IOException
{
try (TestingWireProtocolClient protocolClient = wireProtocolClient()) {
protocolClient.sendStartUpMessage(196608, MOCK_PASSWORD, "test", "canner");
protocolClient.assertAuthOk();
assertDefaultPgConfigResponse(protocolClient);
protocolClient.assertReadyForQuery('I');

// Execute level 1

List<PGType> paramTypes = ImmutableList.of(INTEGER);
protocolClient.sendParse("", "select typname from pg_type where pg_type.oid = ?",
paramTypes.stream().map(PGType::oid).collect(toImmutableList()));
protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.STATEMENT, "");
protocolClient.sendBind("", "", ImmutableList.of(textParameter(14, INTEGER)));
protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.PORTAL, "");
protocolClient.sendExecute("", 0);

protocolClient.assertParseComplete();

List<PGType<?>> actualParamTypes = protocolClient.assertAndGetParameterDescription();
AssertionsForClassTypes.assertThat(actualParamTypes).isEqualTo(paramTypes);

List<TestingWireProtocolClient.Field> fields = protocolClient.assertAndGetRowDescriptionFields();
List<PGType> actualTypes = fields.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList());
AssertionsForClassTypes.assertThat(actualTypes).isEqualTo(ImmutableList.of(VARCHAR));

protocolClient.assertBindComplete();

List<TestingWireProtocolClient.Field> fields2 = protocolClient.assertAndGetRowDescriptionFields();
List<PGType> actualTypes2 = fields2.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList());
AssertionsForClassTypes.assertThat(actualTypes2).isEqualTo(ImmutableList.of(VARCHAR));

protocolClient.assertDataRow("bigint");
protocolClient.assertCommandComplete("SELECT 1");

// Execute Level 3
paramTypes = ImmutableList.of(INTEGER);
protocolClient.sendParse("", "select * from (values ('rows1', 10), ('rows2', 10)) as t(col1, col2) where col2 = ?",
paramTypes.stream().map(PGType::oid).collect(toImmutableList()));
protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.STATEMENT, "");
protocolClient.sendBind("", "", ImmutableList.of(textParameter(10, INTEGER)));
protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.PORTAL, "");
protocolClient.sendExecute("", 0);
protocolClient.sendSync();

protocolClient.assertParseComplete();

actualParamTypes = protocolClient.assertAndGetParameterDescription();
AssertionsForClassTypes.assertThat(actualParamTypes).isEqualTo(paramTypes);

fields = protocolClient.assertAndGetRowDescriptionFields();
actualTypes = fields.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList());
AssertionsForClassTypes.assertThat(actualTypes).isEqualTo(ImmutableList.of(VARCHAR, INTEGER));

protocolClient.assertBindComplete();

fields2 = protocolClient.assertAndGetRowDescriptionFields();
actualTypes2 = fields2.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList());
AssertionsForClassTypes.assertThat(actualTypes2).isEqualTo(ImmutableList.of(VARCHAR, INTEGER));

protocolClient.assertDataRow("rows1,10");
protocolClient.assertDataRow("rows2,10");
protocolClient.assertCommandComplete("SELECT 2");
protocolClient.assertReadyForQuery('I');
}
}
}

0 comments on commit 45f47fe

Please sign in to comment.