diff --git a/accio-main/src/main/java/io/accio/main/wireprotocol/PostgresWireProtocol.java b/accio-main/src/main/java/io/accio/main/wireprotocol/PostgresWireProtocol.java index 7c4684eea..eadd078d8 100644 --- a/accio-main/src/main/java/io/accio/main/wireprotocol/PostgresWireProtocol.java +++ b/accio-main/src/main/java/io/accio/main/wireprotocol/PostgresWireProtocol.java @@ -469,6 +469,8 @@ private void handleExecute(ByteBuf buffer, Channel channel) portal.getRowCount(), resultFormatCodes); portal.setRowCount(resultSetSender.sendResultSet()); + // clean metadata query after executed + wireProtocolSession.removeMetadataQuery(portal.getPreparedStatement().getName(), portal.getName()); } catch (Exception e) { LOG.error(e, format("Execute query failed. Statement: %s. Root cause is %s", statement, e.getMessage())); diff --git a/accio-main/src/main/java/io/accio/main/wireprotocol/WireProtocolSession.java b/accio-main/src/main/java/io/accio/main/wireprotocol/WireProtocolSession.java index 3ccf4f2b4..a5da663f1 100644 --- a/accio-main/src/main/java/io/accio/main/wireprotocol/WireProtocolSession.java +++ b/accio-main/src/main/java/io/accio/main/wireprotocol/WireProtocolSession.java @@ -61,7 +61,6 @@ import static io.accio.main.wireprotocol.PostgresWireProtocol.isIgnoredCommand; import static io.accio.main.wireprotocol.PostgresWireProtocolErrorCode.INVALID_PREPARED_STATEMENT_NAME; import static io.accio.main.wireprotocol.PreparedStatement.RESERVED_DRY_RUN_NAME; -import static io.accio.main.wireprotocol.PreparedStatement.cloneWithName; import static io.accio.main.wireprotocol.WireProtocolSession.PreparedStmtPortalName.preparedStmtPortalName; import static io.accio.main.wireprotocol.patterns.PostgreSqlRewriteUtil.rewriteWithParameters; import static io.trino.execution.ParameterExtractor.getParameterCount; @@ -243,8 +242,7 @@ public List describeStatement(String name) */ public Optional> dryRunAfterDescribeStatement(String statementName, List params, @Nullable FormatCodes.FormatCode[] resultFormatCodes) { - preparedStatements.put(RESERVED_DRY_RUN_NAME, cloneWithName(preparedStatements.get(statementName), RESERVED_DRY_RUN_NAME)); - metadataQueries.put(preparedStmtPortalName(RESERVED_DRY_RUN_NAME, null), Query.builder(preparedStatements.get(RESERVED_DRY_RUN_NAME)).build()); + parse(RESERVED_DRY_RUN_NAME, preparedStatements.get(statementName).getOriginalStatement(), preparedStatements.get(statementName).getParamTypeOids()); bind(RESERVED_DRY_RUN_NAME, RESERVED_DRY_RUN_NAME, params, resultFormatCodes); Optional> result = describePortal(RESERVED_DRY_RUN_NAME); @@ -326,7 +324,7 @@ private void parseDataSourceQuery(String statementName, String statement, List params, @ private void resetMetadataQuery(String statementName, String portalName) { PreparedStmtPortalName name = preparedStmtPortalName(statementName, portalName); - metadataQueries.get(name).getPortal().ifPresent(Portal::close); + Optional.ofNullable(metadataQueries.get(name)).flatMap(Query::getPortal).ifPresent(Portal::close); metadataQueries.remove(name); } @@ -453,6 +451,11 @@ private Optional executeCache(Portal portal) }); } + public void removeMetadataQuery(String statementName, String portalName) + { + metadataQueries.remove(preparedStmtPortalName(statementName, portalName)); + } + private CompletableFuture>> executeSessionCommand(Portal portal) { throw new UnsupportedOperationException(); diff --git a/accio-tests/src/test/java/io/accio/testing/TestMetadataQuery.java b/accio-tests/src/test/java/io/accio/testing/TestMetadataQuery.java index 7ef5baa14..1ca39f841 100644 --- a/accio-tests/src/test/java/io/accio/testing/TestMetadataQuery.java +++ b/accio-tests/src/test/java/io/accio/testing/TestMetadataQuery.java @@ -23,6 +23,7 @@ import org.assertj.core.api.AssertionsForClassTypes; import org.testng.annotations.Test; +import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.sql.Connection; @@ -190,4 +191,74 @@ public void testExecuteAndDescribeLevel2Query() protocolClient.assertCommandComplete("SELECT 1"); } } + + @Test + public void testQueryLevelIsolationIssueInOneTransaction() + throws IOException + { + try (TestingWireProtocolClient protocolClient = wireProtocolClient()) { + protocolClient.sendStartUpMessage(196608, MOCK_PASSWORD, "test", "canner"); + protocolClient.assertAuthOk(); + assertDefaultPgConfigResponse(protocolClient); + protocolClient.assertReadyForQuery('I'); + + // Execute level 1 + + List paramTypes = ImmutableList.of(INTEGER); + protocolClient.sendParse("", "select typname from pg_type where pg_type.oid = ?", + paramTypes.stream().map(PGType::oid).collect(toImmutableList())); + protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.STATEMENT, ""); + protocolClient.sendBind("", "", ImmutableList.of(textParameter(14, INTEGER))); + protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.PORTAL, ""); + protocolClient.sendExecute("", 0); + + protocolClient.assertParseComplete(); + + List> actualParamTypes = protocolClient.assertAndGetParameterDescription(); + AssertionsForClassTypes.assertThat(actualParamTypes).isEqualTo(paramTypes); + + List fields = protocolClient.assertAndGetRowDescriptionFields(); + List actualTypes = fields.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList()); + AssertionsForClassTypes.assertThat(actualTypes).isEqualTo(ImmutableList.of(VARCHAR)); + + protocolClient.assertBindComplete(); + + List fields2 = protocolClient.assertAndGetRowDescriptionFields(); + List actualTypes2 = fields2.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList()); + AssertionsForClassTypes.assertThat(actualTypes2).isEqualTo(ImmutableList.of(VARCHAR)); + + protocolClient.assertDataRow("bigint"); + protocolClient.assertCommandComplete("SELECT 1"); + + // Execute Level 3 + paramTypes = ImmutableList.of(INTEGER); + protocolClient.sendParse("", "select * from (values ('rows1', 10), ('rows2', 10)) as t(col1, col2) where col2 = ?", + paramTypes.stream().map(PGType::oid).collect(toImmutableList())); + protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.STATEMENT, ""); + protocolClient.sendBind("", "", ImmutableList.of(textParameter(10, INTEGER))); + protocolClient.sendDescribe(TestingWireProtocolClient.DescribeType.PORTAL, ""); + protocolClient.sendExecute("", 0); + protocolClient.sendSync(); + + protocolClient.assertParseComplete(); + + actualParamTypes = protocolClient.assertAndGetParameterDescription(); + AssertionsForClassTypes.assertThat(actualParamTypes).isEqualTo(paramTypes); + + fields = protocolClient.assertAndGetRowDescriptionFields(); + actualTypes = fields.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList()); + AssertionsForClassTypes.assertThat(actualTypes).isEqualTo(ImmutableList.of(VARCHAR, INTEGER)); + + protocolClient.assertBindComplete(); + + fields2 = protocolClient.assertAndGetRowDescriptionFields(); + actualTypes2 = fields2.stream().map(TestingWireProtocolClient.Field::getTypeId).map(PGTypes::oidToPgType).collect(toImmutableList()); + AssertionsForClassTypes.assertThat(actualTypes2).isEqualTo(ImmutableList.of(VARCHAR, INTEGER)); + + protocolClient.assertDataRow("rows1,10"); + protocolClient.assertDataRow("rows2,10"); + protocolClient.assertCommandComplete("SELECT 2"); + protocolClient.assertReadyForQuery('I'); + } + } }