From e88e2b133fec6c87c8bc1b9805d8e4c8627a21d7 Mon Sep 17 00:00:00 2001 From: chenjian2664 Date: Tue, 10 Sep 2024 12:33:59 +0800 Subject: [PATCH] Support merge in Postgresql connector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Grzegorz KokosiƄski <7569403+kokosing@users.noreply.github.com> --- docs/src/main/sphinx/connector/postgresql.md | 12 ++ .../io/trino/plugin/jdbc/BaseJdbcClient.java | 6 + .../trino/plugin/jdbc/CachingJdbcClient.java | 6 + .../plugin/jdbc/DefaultJdbcMetadata.java | 16 ++- .../plugin/jdbc/ForwardingJdbcClient.java | 6 + .../java/io/trino/plugin/jdbc/JdbcClient.java | 2 + .../io/trino/plugin/jdbc/JdbcWriteConfig.java | 18 +++ .../jdbc/JdbcWriteSessionProperties.java | 13 +- .../trino/plugin/jdbc/RetryingJdbcClient.java | 7 + .../jdbc/jmx/StatisticsAwareJdbcClient.java | 6 + .../plugin/jdbc/BaseJdbcConnectorTest.java | 18 +-- .../plugin/jdbc/TestJdbcWriteConfig.java | 5 +- .../plugin/hive/BaseHiveConnectorTest.java | 4 +- .../plugin/kudu/TestKuduConnectorTest.java | 44 +++--- .../trino/plugin/phoenix5/PhoenixClient.java | 6 + .../phoenix5/TestPhoenixConnectorTest.java | 12 +- .../plugin/postgresql/PostgreSqlClient.java | 35 +++++ .../BasePostgresFailureRecoveryTest.java | 28 +++- .../TestPostgreSqlConnectorTest.java | 56 +++++++- .../TestPostgreSqlJdbcConnectionAccesses.java | 43 +++++- .../TestPostgreSqlJdbcConnectionCreation.java | 36 ++++- .../TestRemoteQueryCommentLogging.java | 3 +- .../io/trino/testing/BaseConnectorTest.java | 126 ++++++++++-------- .../testing/BaseFailureRecoveryTest.java | 41 +++++- 24 files changed, 439 insertions(+), 110 deletions(-) diff --git a/docs/src/main/sphinx/connector/postgresql.md b/docs/src/main/sphinx/connector/postgresql.md index 809ea463c6b0..62d3d7789518 100644 --- a/docs/src/main/sphinx/connector/postgresql.md +++ b/docs/src/main/sphinx/connector/postgresql.md @@ -115,6 +115,18 @@ catalog named `sales` using the configured connector. ```{include} non-transactional-insert.fragment ``` +### Non-transactional MERGE + +The connector supports adding rows using {doc}`MERGE statements `. +However, the connector only support merge modifying directly to the target +table at current, to use merge you need to set the `merge.non-transactional-merge.enabled` +catalog property or the corresponding `non_transactional_merge_enabled` catalog session property to +`true`. + +Note that with this property enabled, data can be corrupted in rare cases where +exceptions occur during the merge operation. With transactions disabled, no +rollback can be performed. + (postgresql-fte-support)= ### Fault-tolerant execution support diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java index db7a09ffedc7..b866c1b5b19d 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/BaseJdbcClient.java @@ -1584,6 +1584,12 @@ public boolean isLimitGuaranteed(ConnectorSession session) throw new TrinoException(JDBC_ERROR, "limitFunction() is implemented without isLimitGuaranteed()"); } + @Override + public boolean supportsMerge() + { + return false; + } + @Override public String quoted(String name) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java index 64beeea124a4..25b6dcc177aa 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/CachingJdbcClient.java @@ -354,6 +354,12 @@ public boolean isLimitGuaranteed(ConnectorSession session) return delegate.isLimitGuaranteed(session); } + @Override + public boolean supportsMerge() + { + return delegate.supportsMerge(); + } + @Override public Optional getTableHandle(ConnectorSession session, SchemaTableName schemaTableName) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index c76a0a62b9c9..9fd33112b2e1 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -114,6 +114,7 @@ import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isJoinPushdownEnabled; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.isTopNPushdownEnabled; import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalInsert; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.isNonTransactionalMerge; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.connector.RetryMode.NO_RETRIES; import static io.trino.spi.connector.RowChangeParadigm.CHANGE_ONLY_UPDATED_COLUMNS; @@ -1297,11 +1298,24 @@ public ColumnHandle getMergeRowIdColumnHandle(ConnectorSession session, Connecto @Override public ConnectorMergeTableHandle beginMerge(ConnectorSession session, ConnectorTableHandle tableHandle, Map> updateColumnHandles, RetryMode retryMode) { + if (retryMode != NO_RETRIES) { + throw new TrinoException(NOT_SUPPORTED, "This connector does not support MERGE with fault-tolerant execution"); + } + + if (!jdbcClient.supportsMerge()) { + throw new TrinoException(NOT_SUPPORTED, MODIFYING_ROWS_MESSAGE); + } + + if (!isNonTransactionalMerge(session)) { + throw new TrinoException(NOT_SUPPORTED, "Non-transactional MERGE is disabled"); + } + JdbcTableHandle handle = (JdbcTableHandle) tableHandle; checkArgument(handle.isNamedRelation(), "Merge target must be named relation table"); + List primaryKeys = jdbcClient.getPrimaryKeys(session, handle.getRequiredNamedRelation().getRemoteTableName()); if (primaryKeys.isEmpty()) { - throw new TrinoException(NOT_SUPPORTED, MODIFYING_ROWS_MESSAGE); + throw new TrinoException(NOT_SUPPORTED, "The connector can not perform merge on the target table without primary keys"); } SchemaTableName schemaTableName = handle.getRequiredNamedRelation().getSchemaTableName(); diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java index 9c1829199074..2a8f75b3e506 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/ForwardingJdbcClient.java @@ -353,6 +353,12 @@ public boolean isLimitGuaranteed(ConnectorSession session) return delegate().isLimitGuaranteed(session); } + @Override + public boolean supportsMerge() + { + return delegate().supportsMerge(); + } + @Override public Optional getTableComment(ResultSet resultSet) throws SQLException diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java index bd8d07e9748a..001dec0827ee 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcClient.java @@ -166,6 +166,8 @@ Optional legacyImplementJoin( boolean isLimitGuaranteed(ConnectorSession session); + boolean supportsMerge(); + default Optional getTableComment(ResultSet resultSet) throws SQLException { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java index bc042e1e0792..6b1c5e3dc3bf 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteConfig.java @@ -30,6 +30,10 @@ public class JdbcWriteConfig // This means that the write operation can fail and leave the table in an inconsistent state. private boolean nonTransactionalInsert; + // Do not create temporary table during merge. + // This means that the write operation can fail and leave the table in an inconsistent state. + private boolean nonTransactionalMerge; + @Min(1) @Max(MAX_ALLOWED_WRITE_BATCH_SIZE) public int getWriteBatchSize() @@ -59,6 +63,20 @@ public JdbcWriteConfig setNonTransactionalInsert(boolean nonTransactionalInsert) return this; } + public boolean isNonTransactionalMerge() + { + return nonTransactionalMerge; + } + + @Config("merge.non-transactional-merge.enabled") + @ConfigDescription("Enables support for non-transactional MERGE. " + + "This means that the write operation can fail and leave the table in an inconsistent state.") + public JdbcWriteConfig setNonTransactionalMerge(boolean nonTransactionalMerge) + { + this.nonTransactionalMerge = nonTransactionalMerge; + return this; + } + @Min(1) @Max(128) public int getWriteParallelism() diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java index 78e6d12d2298..04e5555c8590 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/JdbcWriteSessionProperties.java @@ -33,6 +33,7 @@ public class JdbcWriteSessionProperties { public static final String WRITE_BATCH_SIZE = "write_batch_size"; public static final String NON_TRANSACTIONAL_INSERT = "non_transactional_insert"; + public static final String NON_TRANSACTIONAL_MERGE = "non_transactional_merge"; public static final String WRITE_PARALLELISM = "write_parallelism"; private final List> properties; @@ -49,9 +50,14 @@ public JdbcWriteSessionProperties(JdbcWriteConfig writeConfig) false)) .add(booleanProperty( NON_TRANSACTIONAL_INSERT, - "Do not use temporary table on insert to table", + "Enables support for non-transactional MERGE", writeConfig.isNonTransactionalInsert(), false)) + .add(booleanProperty( + NON_TRANSACTIONAL_MERGE, + "Do not use temporary table on merge", + writeConfig.isNonTransactionalMerge(), + false)) .add(integerProperty( WRITE_PARALLELISM, "Maximum number of parallel write tasks", @@ -81,6 +87,11 @@ public static boolean isNonTransactionalInsert(ConnectorSession session) return session.getProperty(NON_TRANSACTIONAL_INSERT, Boolean.class); } + public static boolean isNonTransactionalMerge(ConnectorSession session) + { + return session.getProperty(NON_TRANSACTIONAL_MERGE, Boolean.class); + } + private static void validateWriteBatchSize(int maxBatchSize) { if (maxBatchSize < 1) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingJdbcClient.java index 80607e3a52cc..0c59706f254b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RetryingJdbcClient.java @@ -272,6 +272,13 @@ public boolean isLimitGuaranteed(ConnectorSession session) return delegate.isLimitGuaranteed(session); } + @Override + public boolean supportsMerge() + { + // there should be no remote database interaction + return delegate.supportsMerge(); + } + @Override public Optional getTableComment(ResultSet resultSet) throws SQLException diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java index 18def16e4444..24690d0519cd 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/jmx/StatisticsAwareJdbcClient.java @@ -440,6 +440,12 @@ public boolean isLimitGuaranteed(ConnectorSession session) return delegate().isLimitGuaranteed(session); } + @Override + public boolean supportsMerge() + { + return delegate().supportsMerge(); + } + @Override public void createSchema(ConnectorSession session, String schemaName) { diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 1f45eb0e287b..59e9c9d75f5e 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1875,8 +1875,8 @@ public void testConstantUpdateWithVarcharEqualityPredicates() public void testConstantUpdateWithVarcharInequalityPredicates() { skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE)); - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"))) { - if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY)) { + try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) { + if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) { assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", MODIFYING_ROWS_MESSAGE); return; } @@ -1890,8 +1890,8 @@ public void testConstantUpdateWithVarcharInequalityPredicates() public void testConstantUpdateWithVarcharGreaterAndLowerPredicate() { skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_UPDATE)); - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"))) { - if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY)) { + try (TestTable table = createTestTableForWrites("test_update_varchar", "(col1 INT, col2 varchar(1))", ImmutableList.of("1, 'a'", "2, 'A'"), "col2")) { + if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) { assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 > 'A'", MODIFYING_ROWS_MESSAGE); assertQueryFails("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'A'", MODIFYING_ROWS_MESSAGE); return; @@ -1943,14 +1943,14 @@ public void testDeleteWithVarcharInequalityPredicate() { skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE)); // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_varchar", "(col varchar(1))", ImmutableList.of("'a'", "'A'", "null"))) { + try (TestTable table = createTestTableForWrites("test_delete_varchar", "(col varchar(1), pk int)", ImmutableList.of("'a', 0", "'A', 1", "null, 2"), "pk")) { if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_MERGE)) { assertQueryFails("DELETE FROM " + table.getName() + " WHERE col != 'A'", MODIFYING_ROWS_MESSAGE); return; } assertUpdate("DELETE FROM " + table.getName() + " WHERE col != 'A'", 1); - assertQuery("SELECT * FROM " + table.getName(), "VALUES 'A', null"); + assertQuery("SELECT col FROM " + table.getName(), "VALUES 'A', null"); } } @@ -1959,7 +1959,7 @@ public void testDeleteWithVarcharGreaterAndLowerPredicate() { skipTestUnless(hasBehavior(SUPPORTS_CREATE_TABLE) && hasBehavior(SUPPORTS_ROW_LEVEL_DELETE)); // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_varchar", "(col varchar(1))", ImmutableList.of("'0'", "'a'", "'A'", "'b'", "null"))) { + try (TestTable table = createTestTableForWrites("test_delete_varchar", "(col varchar(1), pk int)", ImmutableList.of("'0', 0", "'a', 1", "'A', 2", "'b', 3", "null, 4"), "pk")) { if (!hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY) && !hasBehavior(SUPPORTS_MERGE)) { assertQueryFails("DELETE FROM " + table.getName() + " WHERE col < 'A'", MODIFYING_ROWS_MESSAGE); assertQueryFails("DELETE FROM " + table.getName() + " WHERE col > 'A'", MODIFYING_ROWS_MESSAGE); @@ -1967,9 +1967,9 @@ public void testDeleteWithVarcharGreaterAndLowerPredicate() } assertUpdate("DELETE FROM " + table.getName() + " WHERE col < 'A'", 1); - assertQuery("SELECT * FROM " + table.getName(), "VALUES 'a', 'A', 'b', null"); + assertQuery("SELECT col FROM " + table.getName(), "VALUES 'a', 'A', 'b', null"); assertUpdate("DELETE FROM " + table.getName() + " WHERE col > 'A'", 2); - assertQuery("SELECT * FROM " + table.getName(), "VALUES 'A', null"); + assertQuery("SELECT col FROM " + table.getName(), "VALUES 'A', null"); } } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java index 94ac8ba7aabb..d3c38febdf7b 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/TestJdbcWriteConfig.java @@ -34,7 +34,8 @@ public void testDefaults() assertRecordedDefaults(recordDefaults(JdbcWriteConfig.class) .setWriteBatchSize(1000) .setWriteParallelism(8) - .setNonTransactionalInsert(false)); + .setNonTransactionalInsert(false) + .setNonTransactionalMerge(false)); } @Test @@ -43,12 +44,14 @@ public void testExplicitPropertyMappings() Map properties = ImmutableMap.builder() .put("write.batch-size", "24") .put("insert.non-transactional-insert.enabled", "true") + .put("merge.non-transactional-merge.enabled", "true") .put("write.parallelism", "16") .buildOrThrow(); JdbcWriteConfig expected = new JdbcWriteConfig() .setWriteBatchSize(24) .setNonTransactionalInsert(true) + .setNonTransactionalMerge(true) .setWriteParallelism(16); assertFullMapping(properties, expected); diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index 1db2f6779e8f..f182f68d5a36 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -283,9 +283,9 @@ public void verifySupportsRowLevelUpdateDeclaration() } @Override - protected String createTableForWrites(String createTable) + protected void createTableForWrites(String createTable, String tableName, Optional primaryKey, OptionalInt updateCount) { - return createTable + " WITH (transactional = true)"; + assertUpdate(createTable + " WITH (transactional = true)"); } @Override diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java index 8fdb50247e75..be6690f7542b 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java @@ -78,9 +78,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) } @Override - protected String createTableForWrites(String createTable) + protected void createTableForWrites(String createTable, String tableName, Optional primaryKey) { - return createKuduTableForWrites(createTable); + assertUpdate(createKuduTableForWrites(format(createTable, tableName))); } public static String createKuduTableForWrites(String createTable) @@ -329,7 +329,7 @@ public void testExplainAnalyzeWithDeleteWithSubquery() String tableName = "test_delete_" + randomNameSuffix(); // delete using a subquery - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO %s SELECT * FROM nation".formatted(tableName), 25); assertExplainAnalyze("EXPLAIN ANALYZE DELETE FROM " + tableName + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%' LIMIT 1)", "SemiJoin.*"); @@ -639,7 +639,7 @@ public void testDelete() { // delete successive parts of the table withTableName("test_delete", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE custkey <= 100", "SELECT count(*) FROM orders WHERE custkey <= 100"); assertQuery("SELECT * FROM " + tableName, "SELECT * FROM orders WHERE custkey > 100"); @@ -652,7 +652,7 @@ public void testDelete() }); withTableName("test_delete", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE custkey <= 100", "SELECT count(*) FROM orders WHERE custkey <= 100"); assertQuery("SELECT * FROM " + tableName, "SELECT * FROM orders WHERE custkey > 100"); @@ -666,14 +666,14 @@ public void testDelete() // delete without matching any rows withTableName("test_delete", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE orderkey < 0", 0); }); // delete with a predicate that optimizes to false withTableName("test_delete", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE orderkey > 5 AND orderkey < 4", 0); }); @@ -684,7 +684,7 @@ public void testDelete() public void testDeleteWithLike() { withTableName("test_with_like", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO %s SELECT * FROM nation".formatted(tableName), 25); assertUpdate("DELETE FROM " + tableName + " WHERE name LIKE '%a%'", "VALUES 0"); assertUpdate("DELETE FROM " + tableName + " WHERE name LIKE '%A%'", "SELECT count(*) FROM nation WHERE name LIKE '%A%'"); @@ -710,7 +710,7 @@ protected TestTable createTableWithOneIntegerColumn(String namePrefix) public void testUpdateWithPredicates() { withTableName("test_update_with_predicates", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s (a INT, b INT, c INT)".formatted(tableName))); + createTableForWrites("CREATE TABLE %s (a INT, b INT, c INT)", tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " VALUES (1, 2, 3), (11, 12, 13), (21, 22, 23)", 3); assertUpdate("UPDATE " + tableName + " SET a = a - 1 WHERE c = 3", 1); assertQuery("SELECT * FROM " + tableName, "VALUES (0, 2, 3), (11, 12, 13), (21, 22, 23)"); @@ -732,7 +732,7 @@ public void testUpdateWithPredicates() public void testUpdateAllValues() { withTableName("test_update_all_columns", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s (a INT, b INT, c INT)".formatted(tableName))); + createTableForWrites("CREATE TABLE %s (a INT, b INT, c INT)", tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " VALUES (1, 2, 3), (11, 12, 13), (21, 22, 23)", 3); assertUpdate("UPDATE " + tableName + " SET b = b - 1, c = c * 2", 3); assertQuery("SELECT * FROM " + tableName, "VALUES (1, 1, 6), (11, 11, 26), (21, 21, 46)"); @@ -792,7 +792,7 @@ public void testLimitPushdown() public void testDeleteWithComplexPredicate() { withTableName("test_delete_complex", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE orderkey % 2 = 0", "SELECT count(*) FROM orders WHERE orderkey % 2 = 0"); assertQuery("SELECT * FROM " + tableName, "SELECT * FROM orders WHERE orderkey % 2 <> 0"); @@ -810,7 +810,7 @@ public void testDeleteWithSubquery() { // TODO (https://github.com/trinodb/trino/issues/13210) Migrate these tests to AbstractTestEngineOnlyQueries withTableName("test_delete_subquery", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); assertUpdate("DELETE FROM " + tableName + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%')", 15); assertQuery( @@ -819,7 +819,7 @@ public void testDeleteWithSubquery() }); withTableName("test_delete_subquery", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); // delete using a scalar and EXISTS subquery @@ -830,7 +830,7 @@ public void testDeleteWithSubquery() }); withTableName("test_delete_subquery", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); // delete using correlated EXISTS subquery @@ -841,7 +841,7 @@ public void testDeleteWithSubquery() }); withTableName("test_delete_subquery", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); // delete using correlated IN subquery assertUpdate(format("DELETE FROM %1$s WHERE regionkey IN (SELECT regionkey FROM region WHERE regionkey = %1$s.regionkey AND name LIKE 'A%%')", tableName), 15); @@ -856,7 +856,7 @@ public void testDeleteWithSubquery() public void testDeleteWithSemiJoin() { withTableName("test_delete_semijoin", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); // delete with multiple SemiJoin assertUpdate( @@ -872,7 +872,7 @@ public void testDeleteWithSemiJoin() }); withTableName("test_delete_semijoin", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); // delete with SemiJoin null handling @@ -893,7 +893,7 @@ public void testDeleteWithSemiJoin() public void testDeleteWithVarcharPredicate() { withTableName("test_delete_varchar", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, ORDER_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + ORDER_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM orders", 15000); assertUpdate("DELETE FROM " + tableName + " WHERE orderstatus = 'O'", "SELECT count(*) FROM orders WHERE orderstatus = 'O'"); assertQuery("SELECT * FROM " + tableName, "SELECT * FROM orders WHERE orderstatus <> 'O'"); @@ -905,7 +905,7 @@ public void testDeleteWithVarcharPredicate() public void testDeleteAllDataFromTable() { withTableName("test_delete_all_data", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, REGION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + REGION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM region", 5); // not using assertUpdate as some connectors provide update count and some not @@ -919,7 +919,7 @@ public void testDeleteAllDataFromTable() public void testRowLevelDelete() { withTableName("test_row_delete", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, REGION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + REGION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM region", 5); assertUpdate("DELETE FROM " + tableName + " WHERE regionkey = 2", 1); assertQuery("SELECT count(*) FROM " + tableName, "VALUES 4"); @@ -936,7 +936,7 @@ public void testRowLevelDelete() public void testUpdate() { withTableName("test_update", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); assertUpdate("UPDATE " + tableName + " SET nationkey = 100 WHERE regionkey = 2", 5); assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey = 100", "VALUES 5"); @@ -962,7 +962,7 @@ public void testUpdateMultipleCondition() {} public void testRowLevelUpdate() { withTableName("test_update", tableName -> { - assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + createTableForWrites("CREATE TABLE %s " + NATION_COLUMNS, tableName, Optional.empty()); assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); assertUpdate("UPDATE " + tableName + " SET nationkey = 100 + nationkey WHERE regionkey = 2", 5); assertThat(query("SELECT * FROM " + tableName)) diff --git a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java index d96993a2b3ba..4ee1225eb034 100644 --- a/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java +++ b/plugin/trino-phoenix5/src/main/java/io/trino/plugin/phoenix5/PhoenixClient.java @@ -801,6 +801,12 @@ public void setColumnType(ConnectorSession session, JdbcTableHandle handle, Jdbc throw new TrinoException(NOT_SUPPORTED, "This connector does not support setting column types"); } + @Override + public boolean supportsMerge() + { + return true; + } + @Override public List getPrimaryKeys(ConnectorSession session, RemoteTableName remoteTableName) { diff --git a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java index dede278b28ff..4797960ab08d 100644 --- a/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java +++ b/plugin/trino-phoenix5/src/test/java/io/trino/plugin/phoenix5/TestPhoenixConnectorTest.java @@ -49,6 +49,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.jdbc.JdbcMetadataSessionProperties.DOMAIN_COMPACTION_THRESHOLD; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.UNSUPPORTED_TYPE_HANDLING; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; @@ -118,6 +119,15 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) }; } + @Override + protected Session getSession() + { + Session session = super.getSession(); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true") + .build(); + } + // TODO: wait https://github.com/trinodb/trino/pull/14939 done and then remove this test @Test @Override @@ -422,7 +432,7 @@ private void testMergeWithSpecifiedRowkeys(String rowkeyDefinition) String targetTable = "merge_multiple_rowkeys_specified_" + randomNameSuffix(); // check the upper case table name also works targetTable = targetTable.toUpperCase(ENGLISH); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer_copy VARCHAR) WITH (rowkeys = '%s')", targetTable, rowkeyDefinition))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR, customer_copy VARCHAR) WITH (rowkeys = '" + rowkeyDefinition + "')", targetTable, Optional.empty(), OptionalInt.empty()); String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct', 'joe_%s')", intValue, 1000, 91000, intValue, intValue, intValue)) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 3108c1be783d..7cc262b79ea9 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -132,6 +132,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.HexFormat; import java.util.List; import java.util.Map; @@ -139,6 +140,7 @@ import java.util.Optional; import java.util.OptionalInt; import java.util.OptionalLong; +import java.util.Set; import java.util.UUID; import java.util.function.BiFunction; import java.util.function.Predicate; @@ -147,6 +149,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; import static io.airlift.slice.Slices.wrappedBuffer; @@ -919,6 +922,12 @@ public boolean isLimitGuaranteed(ConnectorSession session) return true; } + @Override + public boolean supportsMerge() + { + return true; + } + @Override public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) { @@ -1188,6 +1197,32 @@ protected void verifyColumnName(DatabaseMetaData databaseMetadata, String column } } + @Override + public List getPrimaryKeys(ConnectorSession session, RemoteTableName remoteTableName) + { + List columns = getColumns(session, remoteTableName.getSchemaTableName(), remoteTableName); + try (Connection connection = connectionFactory.openConnection(session)) { + DatabaseMetaData metaData = connection.getMetaData(); + + ResultSet primaryKeys = metaData.getPrimaryKeys(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + + Set primaryKeyNames = new HashSet<>(); + while (primaryKeys.next()) { + String columnName = primaryKeys.getString("COLUMN_NAME"); + primaryKeyNames.add(columnName); + } + if (primaryKeyNames.isEmpty()) { + return ImmutableList.of(); + } + return columns.stream() + .filter(column -> primaryKeyNames.contains(column.getColumnName())) + .collect(toImmutableList()); + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + private static ColumnMapping charColumnMapping(int charLength) { if (charLength > CharType.MAX_LENGTH) { diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java index c9cf317e7dba..c5f49addd7db 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/BasePostgresFailureRecoveryTest.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Module; +import io.trino.Session; import io.trino.operator.RetryPolicy; import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; import io.trino.plugin.jdbc.BaseJdbcFailureRecoveryTest; @@ -32,6 +33,8 @@ public abstract class BasePostgresFailureRecoveryTest extends BaseJdbcFailureRecoveryTest { + private TestingPostgreSqlServer postgreSqlServer; + public BasePostgresFailureRecoveryTest(RetryPolicy retryPolicy) { super(retryPolicy); @@ -45,7 +48,8 @@ protected QueryRunner createQueryRunner( Module failureInjectionModule) throws Exception { - return PostgreSqlQueryRunner.builder(closeAfterClass(new TestingPostgreSqlServer())) + this.postgreSqlServer = new TestingPostgreSqlServer(); + return PostgreSqlQueryRunner.builder(closeAfterClass(this.postgreSqlServer)) .setExtraProperties(configProperties) .setCoordinatorProperties(configProperties) .setAdditionalSetup(runner -> { @@ -58,6 +62,14 @@ protected QueryRunner createQueryRunner( .build(); } + @Test + @Override + protected void testDeleteWithSubquery() + { + // TODO: support merge with fte https://github.com/trinodb/trino/issues/23345 + assertThatThrownBy(super::testDeleteWithSubquery).hasMessageContaining("Non-transactional MERGE is disabled"); + } + @Test @Override protected void testUpdateWithSubquery() @@ -66,6 +78,14 @@ protected void testUpdateWithSubquery() abort("skipped"); } + @Test + @Override + protected void testMerge() + { + // TODO: support merge with fte https://github.com/trinodb/trino/issues/23345 + assertThatThrownBy(super::testMerge).hasMessageContaining("Non-transactional MERGE is disabled"); + } + @Test @Override protected void testUpdate() @@ -81,4 +101,10 @@ protected void testUpdate() .withCleanupQuery(cleanupQuery) .isCoordinatorOnly(); } + + @Override + protected void addPrimaryKeyForMergeTarget(Session session, String tableName, String primaryKey) + { + postgreSqlServer.execute("ALTER TABLE %s ADD CONSTRAINT pk_%s PRIMARY KEY (%s)".formatted(tableName, tableName, primaryKey)); + } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 229bde0d5164..345ea11fe68f 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -62,6 +62,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE; import static io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping.AS_ARRAY; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -107,6 +108,15 @@ protected QueryRunner createQueryRunner() .build(); } + @Override + protected Session getSession() + { + Session session = super.getSession(); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true") + .build(); + } + @BeforeAll public void setExtensions() { @@ -128,7 +138,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_JOIN_PUSHDOWN, SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY, SUPPORTS_TOPN_PUSHDOWN, - SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR -> true; + SUPPORTS_TOPN_PUSHDOWN_WITH_VARCHAR, + SUPPORTS_ROW_LEVEL_UPDATE, + SUPPORTS_MERGE -> true; case SUPPORTS_ADD_COLUMN_WITH_COMMENT, SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT, SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN, @@ -1297,6 +1309,23 @@ void testVectorDistanceNotPushdown() } } + @Test + public void testMergeTargetWithNoPrimaryKeys() + { + String tableName = "test_merge_target_no_pks_" + randomNameSuffix(); + assertUpdate("CREATE TABLE " + tableName + " (a int, b int)"); + assertUpdate("INSERT INTO " + tableName + " VALUES(1, 1), (2, 2)", 2); + + assertQueryFails(format("DELETE FROM %s WHERE a IS NOT NULL AND abs(a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys"); + assertQueryFails(format("UPDATE %s SET a = a+b WHERE a IS NOT NULL AND (a + b) > 10", tableName), "The connector can not perform merge on the target table without primary keys"); + assertQueryFails(format("MERGE INTO %s t USING (VALUES (3, 3)) AS s(x, y) " + + " ON t.a = s.x " + + " WHEN MATCHED THEN UPDATE SET b = y " + + " WHEN NOT MATCHED THEN INSERT (a, b) VALUES (s.x, s.y) ", tableName), "The connector can not perform merge on the target table without primary keys"); + + assertUpdate("DROP TABLE " + tableName); + } + private String getLongInClause(int start, int length) { String longValues = range(start, start + length) @@ -1407,4 +1436,29 @@ protected Optional filterSetColumnTypesDataProvider(SetColum return Optional.of(setup); } + + @Override + protected void createTableForWrites(String createTable, String tableName, Optional primaryKey, OptionalInt updateCount) + { + super.createTableForWrites(createTable, tableName, primaryKey, updateCount); + primaryKey.ifPresent(key -> onRemoteDatabase().execute(format("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", tableName, "pk_" + tableName, key))); + } + + @Override + protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, String primaryKey) + { + TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, primaryKey); + String tableName = testTable.getName(); + onRemoteDatabase().execute(format("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", tableName, "pk_" + tableName, primaryKey)); + return testTable; + } + + @Override + protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, List rowsToInsert, String primaryKey) + { + TestTable testTable = super.createTestTableForWrites(namePrefix, tableDefinition, rowsToInsert, primaryKey); + String tableName = testTable.getName(); + onRemoteDatabase().execute(format("ALTER TABLE %s ADD CONSTRAINT %s PRIMARY KEY (%s)", tableName, "pk_" + tableName, primaryKey)); + return testTable; + } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionAccesses.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionAccesses.java index cf69cc8ae1eb..e09d7324c609 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionAccesses.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionAccesses.java @@ -19,6 +19,7 @@ import com.google.inject.Provides; import com.google.inject.Singleton; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectionCreationTest; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DriverConnectionFactory; @@ -35,9 +36,10 @@ import java.util.Properties; import static io.airlift.configuration.ConfigurationAwareModule.combine; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.testing.QueryAssertions.copyTpchTables; +import static io.trino.tpch.TpchTable.CUSTOMER; import static io.trino.tpch.TpchTable.NATION; import static io.trino.tpch.TpchTable.REGION; import static java.util.Objects.requireNonNull; @@ -47,11 +49,14 @@ public class TestPostgreSqlJdbcConnectionAccesses extends BaseJdbcConnectionCreationTest { + private TestingPostgreSqlServer postgreSqlServer; + @Override protected QueryRunner createQueryRunner() throws Exception { TestingPostgreSqlServer postgreSqlServer = closeAfterClass(new TestingPostgreSqlServer()); + this.postgreSqlServer = requireNonNull(postgreSqlServer, "postgreSqlServer is null"); this.connectionFactory = getConnectionCountingConnectionFactory(postgreSqlServer); DistributedQueryRunner queryRunner = PostgreSqlQueryRunner.builder(postgreSqlServer) // to make sure we always open connections in the same way @@ -69,10 +74,19 @@ protected QueryRunner createQueryRunner() "query.reuse-connection", "false")); }) .build(); - copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, ImmutableList.of(NATION, REGION)); + copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, ImmutableList.of(CUSTOMER, NATION, REGION)); return queryRunner; } + @Override + protected Session getSession() + { + Session session = super.getSession(); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true") + .build(); + } + private static ConnectionCountingConnectionFactory getConnectionCountingConnectionFactory(TestingPostgreSqlServer postgreSqlServer) { Properties connectionProperties = new Properties(); @@ -97,19 +111,34 @@ public void testJdbcConnectionCreations() assertJdbcConnections("SELECT * FROM nation JOIN region USING(regionkey)", 10, Optional.empty()); assertJdbcConnections("SELECT * FROM information_schema.schemata", 1, Optional.empty()); assertJdbcConnections("SELECT * FROM information_schema.tables", 1, Optional.empty()); - assertJdbcConnections("SELECT * FROM information_schema.columns", 5, Optional.empty()); + assertJdbcConnections("SELECT * FROM information_schema.columns", 7, Optional.empty()); assertJdbcConnections("SELECT * FROM nation", 3, Optional.empty()); assertJdbcConnections("SELECT * FROM TABLE (system.query(query => 'SELECT * FROM tpch.nation'))", 2, Optional.empty()); assertJdbcConnections("CREATE TABLE copy_of_nation AS SELECT * FROM nation", 15, Optional.empty()); assertJdbcConnections("INSERT INTO copy_of_nation SELECT * FROM nation", 13, Optional.empty()); - assertJdbcConnections("DELETE FROM copy_of_nation WHERE nationkey = 3", 4, Optional.empty()); - assertJdbcConnections("UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 4, Optional.empty()); - assertJdbcConnections("MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 5, Optional.of(MODIFYING_ROWS_MESSAGE)); + assertJdbcConnections("DELETE FROM copy_of_nation WHERE nationkey = 3", 6, Optional.empty()); assertJdbcConnections("DROP TABLE copy_of_nation", 2, Optional.empty()); assertJdbcConnections("SHOW SCHEMAS", 1, Optional.empty()); assertJdbcConnections("SHOW TABLES", 2, Optional.empty()); assertJdbcConnections("SHOW STATS FOR nation", 4, Optional.empty()); - assertJdbcConnections("SELECT * FROM system.jdbc.columns WHERE table_cat = 'counting_postgresql'", 5, Optional.empty()); + assertJdbcConnections("SELECT * FROM system.jdbc.columns WHERE table_cat = 'counting_postgresql'", 7, Optional.empty()); + + testJdbcMergeConnectionCreations(); + } + + private void testJdbcMergeConnectionCreations() + { + assertJdbcConnections("CREATE TABLE copy_of_customer AS SELECT * FROM customer", 15, Optional.empty()); + + addPrimaryKeyToCopyTable(); + assertJdbcConnections("DELETE FROM copy_of_customer WHERE abs(custkey) = 1", 24, Optional.empty()); + assertJdbcConnections("UPDATE copy_of_customer SET name = 'POLAND' WHERE abs(custkey) = 1", 32, Optional.empty()); + assertJdbcConnections("MERGE INTO copy_of_customer c USING customer r ON r.custkey = c.custkey WHEN MATCHED THEN DELETE", 28, Optional.empty()); + } + + private void addPrimaryKeyToCopyTable() + { + postgreSqlServer.execute("ALTER TABLE copy_of_customer ADD CONSTRAINT t_copy_of_customer PRIMARY KEY (custkey)"); } private static final class TestingPostgreSqlModule diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java index 1208b5880fea..7afea0c3a528 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java @@ -19,6 +19,7 @@ import com.google.inject.Provides; import com.google.inject.Singleton; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectionCreationTest; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.DriverConnectionFactory; @@ -35,9 +36,10 @@ import java.util.Properties; import static io.airlift.configuration.ConfigurationAwareModule.combine; +import static io.trino.plugin.jdbc.JdbcWriteSessionProperties.NON_TRANSACTIONAL_MERGE; import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; -import static io.trino.spi.connector.ConnectorMetadata.MODIFYING_ROWS_MESSAGE; import static io.trino.testing.QueryAssertions.copyTpchTables; +import static io.trino.tpch.TpchTable.CUSTOMER; import static io.trino.tpch.TpchTable.NATION; import static io.trino.tpch.TpchTable.REGION; import static java.util.Objects.requireNonNull; @@ -45,11 +47,14 @@ public class TestPostgreSqlJdbcConnectionCreation extends BaseJdbcConnectionCreationTest { + protected TestingPostgreSqlServer postgreSqlServer; + @Override protected QueryRunner createQueryRunner() throws Exception { TestingPostgreSqlServer postgreSqlServer = closeAfterClass(new TestingPostgreSqlServer()); + this.postgreSqlServer = requireNonNull(postgreSqlServer, "postgreSqlServer is null"); this.connectionFactory = getConnectionCountingConnectionFactory(postgreSqlServer); DistributedQueryRunner queryRunner = PostgreSqlQueryRunner.builder(postgreSqlServer) // to make sure we always open connections in the same way @@ -65,10 +70,19 @@ protected QueryRunner createQueryRunner() "connection-password", postgreSqlServer.getPassword())); }) .build(); - copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, ImmutableList.of(NATION, REGION)); + copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, ImmutableList.of(CUSTOMER, NATION, REGION)); return queryRunner; } + @Override + protected Session getSession() + { + Session session = super.getSession(); + return Session.builder(session) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), NON_TRANSACTIONAL_MERGE, "true") + .build(); + } + private static ConnectionCountingConnectionFactory getConnectionCountingConnectionFactory(TestingPostgreSqlServer postgreSqlServer) { Properties connectionProperties = new Properties(); @@ -100,12 +114,28 @@ public void testJdbcConnectionCreations() assertJdbcConnections("INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()); assertJdbcConnections("DELETE FROM copy_of_nation WHERE nationkey = 3", 1, Optional.empty()); assertJdbcConnections("UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 1, Optional.empty()); - assertJdbcConnections("MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 1, Optional.of(MODIFYING_ROWS_MESSAGE)); assertJdbcConnections("DROP TABLE copy_of_nation", 1, Optional.empty()); assertJdbcConnections("SHOW SCHEMAS", 1, Optional.empty()); assertJdbcConnections("SHOW TABLES", 1, Optional.empty()); assertJdbcConnections("SHOW STATS FOR nation", 2, Optional.empty()); assertJdbcConnections("SELECT * FROM system.jdbc.columns WHERE table_cat = 'counting_postgresql'", 1, Optional.empty()); + + testJdbcMergeConnectionCreations(); + } + + private void testJdbcMergeConnectionCreations() + { + assertJdbcConnections("CREATE TABLE copy_of_customer AS SELECT * FROM customer", 6, Optional.empty()); + + addPrimaryKeyToCopyTable(); + assertJdbcConnections("DELETE FROM copy_of_customer WHERE abs(custkey) = 1", 17, Optional.empty()); + assertJdbcConnections("UPDATE copy_of_customer SET name = 'POLAND' WHERE abs(custkey) = 1", 25, Optional.empty()); + assertJdbcConnections("MERGE INTO copy_of_customer c USING customer r ON r.custkey = c.custkey WHEN MATCHED THEN DELETE", 18, Optional.empty()); + } + + private void addPrimaryKeyToCopyTable() + { + postgreSqlServer.execute("ALTER TABLE copy_of_customer ADD CONSTRAINT t_copy_of_nation PRIMARY KEY (custkey)"); } private static final class TestingPostgreSqlModule diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java index 9a083d840c81..3fa0c0cb27fc 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestRemoteQueryCommentLogging.java @@ -62,7 +62,8 @@ public void testShouldLogContextInComment() assertThat(postgreSqlServer.recordEventsForOperations(() -> getQueryRunner().execute("DELETE FROM postgresql.tpch.log_nation_test_table")) .stopEventsRecording() - .streamQueriesContaining("log_nation_test_table")) + // Filter that the identifier not the variable + .streamQueriesContaining("\"log_nation_test_table\"")) .allMatch(query -> query.endsWith("/*query executed by user*/")) .size() .isEqualTo(1); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 04bbb9cbf95f..e53c808fa9bf 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -4649,7 +4649,7 @@ public void testUpdateNotNullColumn() return; } - try (TestTable table = new TestTable(getQueryRunner()::execute, "update_not_null", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)")) { + try (TestTable table = createTestTableForWrites("update_not_null", "(nullable_col INTEGER, not_null_col INTEGER NOT NULL)", "not_null_col")) { assertUpdate(format("INSERT INTO %s (nullable_col, not_null_col) VALUES (1, 10)", table.getName()), 1); assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 10)"); assertQueryFails("UPDATE " + table.getName() + " SET not_null_col = NULL WHERE nullable_col = 1", "NULL value not allowed for NOT NULL column: not_null_col"); @@ -4788,7 +4788,7 @@ public void testDeleteWithComplexPredicate() skipTestUnless(hasBehavior(SUPPORTS_DELETE)); // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_complex_", "AS SELECT * FROM nation")) { + try (TestTable table = createTestTableForWrites("test_delete_complex_", "AS SELECT * FROM nation", "nationkey")) { // delete half the table, then delete the rest assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey % 2 = 0", "SELECT count(*) FROM nation WHERE nationkey % 2 = 0"); assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey % 2 <> 0"); @@ -4807,7 +4807,7 @@ public void testDeleteWithSubquery() skipTestUnless(hasBehavior(SUPPORTS_DELETE)); // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_subquery", "AS SELECT * FROM nation")) { + try (TestTable table = createTestTableForWrites("test_delete_subquery", "AS SELECT * FROM nation", "nationkey")) { // delete using a subquery assertUpdate("DELETE FROM " + table.getName() + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%')", 15); assertQuery( @@ -4816,7 +4816,7 @@ public void testDeleteWithSubquery() } // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_subquery", "AS SELECT * FROM orders")) { + try (TestTable table = createTestTableForWrites("test_delete_subquery", "AS SELECT * FROM orders", "orderkey")) { // delete using a scalar and EXISTS subquery assertUpdate("DELETE FROM " + table.getName() + " WHERE orderkey = (SELECT orderkey FROM orders ORDER BY orderkey LIMIT 1)", 1); assertUpdate("DELETE FROM " + table.getName() + " WHERE orderkey = (SELECT orderkey FROM orders WHERE false)", 0); @@ -4824,7 +4824,7 @@ public void testDeleteWithSubquery() assertUpdate("DELETE FROM " + table.getName() + " WHERE EXISTS(SELECT 1)", "SELECT count(*) - 1 FROM orders"); } - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_correlated_exists_subquery", "AS SELECT * FROM nation")) { + try (TestTable table = createTestTableForWrites("test_delete_correlated_exists_subquery", "AS SELECT * FROM nation", "nationkey")) { // delete using correlated EXISTS subquery assertUpdate(format("DELETE FROM %1$s WHERE EXISTS(SELECT regionkey FROM region WHERE regionkey = %1$s.regionkey AND name LIKE 'A%%')", table.getName()), 15); assertQuery( @@ -4832,7 +4832,7 @@ public void testDeleteWithSubquery() "SELECT * FROM nation WHERE regionkey IN (SELECT regionkey FROM region WHERE name NOT LIKE 'A%')"); } - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_correlated_exists_subquery", "AS SELECT * FROM nation")) { + try (TestTable table = createTestTableForWrites("test_delete_correlated_exists_subquery", "AS SELECT * FROM nation", "nationkey")) { // delete using correlated IN subquery assertUpdate(format("DELETE FROM %1$s WHERE regionkey IN (SELECT regionkey FROM region WHERE regionkey = %1$s.regionkey AND name LIKE 'A%%')", table.getName()), 15); assertQuery( @@ -4841,18 +4841,27 @@ public void testDeleteWithSubquery() } } + protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, String primaryKey) + { + return new TestTable(getQueryRunner()::execute, namePrefix, tableDefinition); + } + + protected TestTable createTestTableForWrites(String namePrefix, String tableDefinition, List rowsToInsert, String primaryKey) + { + return new TestTable(getQueryRunner()::execute, namePrefix, tableDefinition, rowsToInsert); + } + @Test public void testExplainAnalyzeWithDeleteWithSubquery() { skipTestUnless(hasBehavior(SUPPORTS_DELETE)); - String tableName = "test_delete_" + randomNameSuffix(); - - // delete using a subquery - assertUpdate("CREATE TABLE " + tableName + " AS SELECT * FROM nation", 25); - assertExplainAnalyze("EXPLAIN ANALYZE DELETE FROM " + tableName + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%' LIMIT 1)", - "SemiJoin.*"); - assertUpdate("DROP TABLE " + tableName); + try (TestTable table = createTestTableForWrites("test_delete_", "AS SELECT * FROM nation", "nationkey")) { + String tableName = table.getName(); + // delete using a subquery + assertExplainAnalyze("EXPLAIN ANALYZE DELETE FROM " + tableName + " WHERE regionkey IN (SELECT regionkey FROM region WHERE name LIKE 'A%' LIMIT 1)", + "SemiJoin.*"); + } } @Test @@ -4861,7 +4870,7 @@ public void testDeleteWithSemiJoin() skipTestUnless(hasBehavior(SUPPORTS_DELETE)); // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_semijoin", "AS SELECT * FROM nation")) { + try (TestTable table = createTestTableForWrites("test_delete_semijoin", "AS SELECT * FROM nation", "nationkey")) { // delete with multiple SemiJoin assertUpdate( "DELETE FROM " + table.getName() + " " + @@ -4876,7 +4885,7 @@ public void testDeleteWithSemiJoin() } // TODO (https://github.com/trinodb/trino/issues/5901) Use longer table name once Oracle version is updated - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_semijoin", "AS SELECT * FROM orders")) { + try (TestTable table = createTestTableForWrites("test_delete_semijoin", "AS SELECT * FROM orders", "orderkey")) { // delete with SemiJoin null handling assertUpdate( "DELETE FROM " + table.getName() + "\n" + @@ -5014,7 +5023,7 @@ public void testRowLevelUpdate() return; } - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update", "AS TABLE tpch.tiny.nation")) { + try (TestTable table = createTestTableForWrites("test_update", "AS TABLE tpch.tiny.nation", "name,regionkey")) { String tableName = table.getName(); assertUpdate("UPDATE " + tableName + " SET nationkey = 100 + nationkey WHERE regionkey = 2", 5); assertThat(query("SELECT * FROM " + tableName)) @@ -5054,21 +5063,21 @@ public void testUpdateRowConcurrently() int threads = 4; CyclicBarrier barrier = new CyclicBarrier(threads); ExecutorService executor = newFixedThreadPool(threads); - try (TestTable table = new TestTable( - getQueryRunner()::execute, + try (TestTable table = createTestTableForWrites( "test_concurrent_update", - IntStream.range(0, threads) + IntStream.range(0, threads + 1) .mapToObj(i -> format("col%s integer", i)) - .collect(joining(", ", "(", ")")))) { + .collect(joining(", ", "(", ")")), + "col" + threads)) { String tableName = table.getName(); - assertUpdate(format("INSERT INTO %s VALUES (%s)", tableName, join(",", nCopies(threads, "0"))), 1); + assertUpdate(format("INSERT INTO %s VALUES (%s)", tableName, join(",", nCopies(threads + 1, "0"))), 1); List> futures = IntStream.range(0, threads) .mapToObj(threadNumber -> executor.submit(() -> { barrier.await(10, SECONDS); try { String columnName = "col" + threadNumber; - getQueryRunner().execute(format("UPDATE %s SET %s = %s + 1", tableName, columnName, columnName)); + getQueryRunner().execute(getSession(), format("UPDATE %s SET %s = %s + 1", tableName, columnName, columnName)); return true; } catch (Exception e) { @@ -5090,7 +5099,7 @@ public void testUpdateRowConcurrently() String expected = futures.stream() .map(future -> tryGetFutureValue(future, 10, SECONDS).orElseThrow(() -> new RuntimeException("Wait timed out"))) .map(success -> success ? "1" : "0") - .collect(joining(",", "VALUES (", ")")); + .collect(joining(",", "VALUES (", ", 0)")); assertThat(query("TABLE " + tableName)) .matches(expected); @@ -5339,7 +5348,7 @@ public void testUpdateWithPredicates() return; } - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_with_predicates", "(a INT, b INT, c INT)")) { + try (TestTable table = createTestTableForWrites("test_update_with_predicates", "(a INT, b INT, c INT)", "a")) { String tableName = table.getName(); assertUpdate("INSERT INTO " + tableName + " VALUES (1, 2, 3), (11, 12, 13), (21, 22, 23)", 3); assertUpdate("UPDATE " + tableName + " SET a = a - 1 WHERE c = 3", 1); @@ -5402,7 +5411,7 @@ public void testUpdateAllValues() return; } - try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_all_columns", "(a INT, b INT, c INT)")) { + try (TestTable table = createTestTableForWrites("test_update_all_columns", "(a INT, b INT, c INT)", "a")) { String tableName = table.getName(); assertUpdate("INSERT INTO " + tableName + " VALUES (1, 2, 3), (11, 12, 13), (21, 22, 23)", 3); assertUpdate("UPDATE " + tableName + " SET a = a + 1, b = b - 1, c = c * 2", 3); @@ -6199,7 +6208,7 @@ public void testMergeDeleteWithCTAS() (4, 'd', 'dd') ) AS t (id, name, value) """; - assertUpdate(createTableSql.formatted(target), 4); + createTableForWrites(createTableSql, target, Optional.of("id"), OptionalInt.of(4)); assertUpdate(createTableSql.formatted(source), 4); assertQuery("SELECT COUNT(*) FROM " + target, "VALUES 4"); @@ -6212,9 +6221,14 @@ public void testMergeDeleteWithCTAS() assertUpdate("DROP TABLE " + source); } - protected String createTableForWrites(String createTable) + protected void createTableForWrites(String createTable, String tableName, Optional primaryKey) + { + createTableForWrites(createTable, tableName, primaryKey, OptionalInt.empty()); + } + + protected void createTableForWrites(String createTable, String tableName, Optional primaryKey, OptionalInt updateCount) { - return createTable; + updateCount.ifPresentOrElse(count -> assertUpdate(format(createTable, tableName), count), () -> assertUpdate(format(createTable, tableName))); } @Test @@ -6224,7 +6238,7 @@ public void testMergeLarge() String tableName = "test_merge_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (orderkey BIGINT, custkey BIGINT, totalprice DOUBLE)", tableName))); + createTableForWrites("CREATE TABLE %s (orderkey BIGINT, custkey BIGINT, totalprice DOUBLE)", tableName, Optional.of("orderkey")); assertUpdate( format("INSERT INTO %s SELECT orderkey, custkey, totalprice FROM tpch.sf1.orders", tableName), @@ -6258,11 +6272,11 @@ public void testMergeSimpleSelect() String targetTable = "merge_simple_target_" + randomNameSuffix(); String sourceTable = "merge_simple_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); @@ -6284,11 +6298,11 @@ public void testMergeFruits() String targetTable = "merge_various_target_" + randomNameSuffix(); String sourceTable = "merge_various_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Dave', 'dates'), ('Lou', 'limes'), ('Carol', 'candles')", targetTable), 3); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchase VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchase) VALUES ('Craig', 'candles'), ('Len', 'limes'), ('Joe', 'jellybeans')", sourceTable), 3); @@ -6310,7 +6324,7 @@ public void testMergeMultipleOperations() int targetCustomerCount = 32; String targetTable = "merge_multiple_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, zipcode INT, spouse VARCHAR, address VARCHAR)", targetTable, Optional.of("customer")); String originalInsertFirstHalf = IntStream.range(1, targetCustomerCount / 2) .mapToObj(intValue -> format("('joe_%s', %s, %s, 'jan_%s', '%s Poe Ct')", intValue, 1000, 91000, intValue, intValue)) @@ -6375,7 +6389,7 @@ public void testMergeSimpleQuery() skipTestUnless(hasBehavior(SUPPORTS_MERGE)); String targetTable = "merge_query_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); @@ -6398,7 +6412,7 @@ public void testMergeAllInserts() skipTestUnless(hasBehavior(SUPPORTS_MERGE)); String targetTable = "merge_inserts_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); @@ -6419,7 +6433,7 @@ public void testMergeFalseJoinCondition() skipTestUnless(hasBehavior(SUPPORTS_MERGE)); String targetTable = "merge_join_false_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 11, 'Antioch'), ('Bill', 7, 'Buena')", targetTable), 2); @@ -6467,11 +6481,11 @@ public void testMergeAllColumnsUpdated() String targetTable = "merge_all_columns_updated_target_" + randomNameSuffix(); String sourceTable = "merge_all_columns_updated_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Devon'), ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge')", targetTable), 4); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire'), ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Ed', 7, 'Etherville')", sourceTable), 4); @@ -6492,11 +6506,11 @@ public void testMergeAllMatchesDeleted() String targetTable = "merge_all_matches_deleted_target_" + randomNameSuffix(); String sourceTable = "merge_all_matches_deleted_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire'), ('Ed', 7, 'Etherville')", sourceTable), 4); @@ -6518,11 +6532,11 @@ public void testMergeMultipleRowsMatchFails() String targetTable = "merge_multiple_fail_target_" + randomNameSuffix(); String sourceTable = "merge_multiple_fail_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Antioch')", targetTable), 2); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Aaron', 6, 'Adelphi'), (2, 'Aaron', 8, 'Ashland')", sourceTable), 2); @@ -6545,7 +6559,7 @@ public void testMergeQueryWithStrangeCapitalization() skipTestUnless(hasBehavior(SUPPORTS_MERGE)); String targetTable = "merge_strange_capitalization_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); @@ -6569,11 +6583,11 @@ public void testMergeWithoutTablesAliases() String targetTable = "test_without_aliases_target_" + randomNameSuffix(); String sourceTable = "test_without_aliases_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("customer")); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 5, 'Antioch'), ('Bill', 7, 'Buena'), ('Carol', 3, 'Cambridge'), ('Dave', 11, 'Devon')", targetTable), 4); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Aaron', 6, 'Arches'), ('Ed', 7, 'Etherville'), ('Carol', 9, 'Centreville'), ('Dave', 11, 'Darbyshire')", sourceTable), 4); @@ -6598,11 +6612,11 @@ public void testMergeWithUnpredictablePredicates() String targetTable = "merge_predicates_target_" + randomNameSuffix(); String sourceTable = "merge_predicates_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("id")); assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Aaron', 5, 'Antioch'), (2, 'Bill', 7, 'Buena'), (3, 'Carol', 3, 'Cambridge'), (4, 'Dave', 11, 'Devon')", targetTable), 4); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (5, 'Aaron', 6, 'Arches'), (6, 'Carol', 9, 'Centreville'), (7, 'Dave', 11, 'Darbyshire'), (8, 'Ed', 7, 'Etherville')", sourceTable), 4); @@ -6646,11 +6660,11 @@ public void testMergeWithSimplifiedUnpredictablePredicates() String targetTable = "merge_predicates_target_" + randomNameSuffix(); String sourceTable = "merge_predicates_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (id INT, customer VARCHAR, purchases INT, address VARCHAR)", targetTable, Optional.of("id")); assertUpdate(format("INSERT INTO %s (id, customer, purchases, address) VALUES (1, 'Dave', 11, 'Devon'), (2, 'Dave', 11, 'Darbyshire')", targetTable), 2); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (customer VARCHAR, purchases INT, address VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s (customer, purchases, address) VALUES ('Dave', 11, 'Darbyshire')", sourceTable), 1); @@ -6674,11 +6688,11 @@ public void testMergeCasts() String targetTable = "merge_cast_target_" + randomNameSuffix(); String sourceTable = "merge_cast_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (col1 INT, col2 DOUBLE, col3 INT, col4 BIGINT, col5 REAL, col6 DOUBLE)", targetTable))); + createTableForWrites("CREATE TABLE %s (col1 INT, col2 DOUBLE, col3 INT, col4 BIGINT, col5 REAL, col6 DOUBLE)", targetTable, Optional.of("col1")); assertUpdate(format("INSERT INTO %s VALUES (1, 2, 3, 4, 5, 6)", targetTable), 1); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (col1 BIGINT, col2 REAL, col3 DOUBLE, col4 INT, col5 INT, col6 REAL)", sourceTable))); + createTableForWrites("CREATE TABLE %s (col1 BIGINT, col2 REAL, col3 DOUBLE, col4 INT, col5 INT, col6 REAL)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s VALUES (2, 3, 4, 5, 6, 7)", sourceTable), 1); @@ -6701,11 +6715,11 @@ public void testMergeSubqueries() String targetTable = "merge_nation_target_" + randomNameSuffix(); String sourceTable = "merge_nation_source_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable))); + createTableForWrites("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", targetTable, Optional.of("nation_name")); assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable), 3); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable))); + createTableForWrites("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR)", sourceTable, Optional.empty()); assertUpdate(format("INSERT INTO %s VALUES ('ALGERIA', 'AFRICA'), ('FRANCE', 'EUROPE'), ('EGYPT', 'MIDDLE EAST'), ('RUSSIA', 'EUROPE')", sourceTable), 4); @@ -6731,7 +6745,7 @@ public void testMergeNonNullableColumns() String targetTable = "merge_non_nullable_target_" + randomNameSuffix(); - assertUpdate(createTableForWrites(format("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR NOT NULL)", targetTable))); + createTableForWrites("CREATE TABLE %s (nation_name VARCHAR, region_name VARCHAR NOT NULL)", targetTable, Optional.of("nation_name")); assertUpdate(format("INSERT INTO %s (nation_name, region_name) VALUES ('FRANCE', 'EUROPE'), ('ALGERIA', 'AFRICA'), ('GERMANY', 'EUROPE')", targetTable), 3); @@ -6774,7 +6788,7 @@ public void testMergeAllColumnsReversed() skipTestUnless(hasBehavior(SUPPORTS_MERGE) && hasBehavior(SUPPORTS_NOT_NULL_CONSTRAINT)); String targetTable = "merge_update_columns_reversed_" + randomNameSuffix(); - assertUpdate(createTableForWrites("CREATE TABLE " + targetTable + " (a, b, c) AS VALUES (1, 2, 3)"), 1); + createTableForWrites("CREATE TABLE " + targetTable + " (a, b, c) AS VALUES (1, 2, 3)", targetTable, Optional.of("a"), OptionalInt.of(1)); assertUpdate( """ MERGE INTO %s t USING (VALUES(1)) AS s(a) ON (t.a = s.a) diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java index 39e7bcf91355..1d0b1307d3e1 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseFailureRecoveryTest.java @@ -231,10 +231,13 @@ protected void testDelete() @Test protected void testDeleteWithSubquery() { - testTableModification( + testNonSelect( + Optional.empty(), Optional.of("CREATE TABLE AS SELECT * FROM orders"), "DELETE FROM
WHERE custkey IN (SELECT custkey FROM customer WHERE nationkey = 1)", - Optional.of("DROP TABLE
")); + Optional.of("DROP TABLE
"), + true, + Optional.of("orderkey")); } @Test @@ -269,7 +272,8 @@ protected void testAnalyzeTable() @Test protected void testMerge() { - testTableModification( + testNonSelect( + Optional.empty(), Optional.of("CREATE TABLE
AS SELECT * FROM orders"), """ MERGE INTO
t @@ -280,7 +284,9 @@ protected void testMerge() WHEN MATCHED AND s.orderkey <= 1000 THEN DELETE """, - Optional.of("DROP TABLE
")); + Optional.of("DROP TABLE
"), + true, + Optional.of("orderkey")); } @Test @@ -329,6 +335,11 @@ protected void testTableModification(Optional session, Optional } protected void testNonSelect(Optional session, Optional setupQuery, String query, Optional cleanupQuery, boolean writesData) + { + testNonSelect(session, setupQuery, query, cleanupQuery, writesData, Optional.empty()); + } + + protected void testNonSelect(Optional session, Optional setupQuery, String query, Optional cleanupQuery, boolean writesData, Optional primaryKey) { if (writesData && !areWriteRetriesSupported()) { // if retries are not supported assert on that and skip actual failures simulation @@ -336,6 +347,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .failsDespiteRetries(failure -> failure.hasMessageMatching("This connector does not support query retries")) .cleansUpTemporaryTables(); return; @@ -346,6 +358,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) .at(boundaryCoordinatorStage()) .finishesSuccessfully() @@ -356,6 +369,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) .at(boundaryCoordinatorStage()) .failsAlways(failure -> failure.hasMessageContaining(FAILURE_INJECTION_MESSAGE)) @@ -367,6 +381,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) .at(rootStage()) .finishesSuccessfully() @@ -377,6 +392,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) .at(rootStage()) .failsAlways(failure -> failure.hasMessageContaining(FAILURE_INJECTION_MESSAGE)) @@ -387,6 +403,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR)) .at(boundaryDistributedStage()) .failsWithoutRetries(failure -> failure.hasMessageContaining(FAILURE_INJECTION_MESSAGE)) @@ -396,6 +413,7 @@ protected void testNonSelect(Optional session, Optional setupQu assertThatQuery(query) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_MANAGEMENT_REQUEST_TIMEOUT) .at(boundaryDistributedStage()) .failsWithoutRetries(failure -> failure.hasMessageContaining("Encountered too many errors talking to a worker node")) @@ -407,6 +425,7 @@ protected void testNonSelect(Optional session, Optional setupQu .withSession(session) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_GET_RESULTS_REQUEST_FAILURE) .at(boundaryDistributedStage()) .failsWithoutRetries(failure -> failure.hasMessageFindingMatch("Error 500 Internal Server Error|Error closing remote buffer, expected 204 got 500")) @@ -416,6 +435,7 @@ protected void testNonSelect(Optional session, Optional setupQu assertThatQuery(query) .withSetupQuery(setupQuery) .withCleanupQuery(cleanupQuery) + .withPrimaryKey(primaryKey) .experiencing(TASK_GET_RESULTS_REQUEST_TIMEOUT) .at(boundaryDistributedStage()) .failsWithoutRetries(failure -> failure.hasMessageFindingMatch("Encountered too many errors talking to a worker node|Error closing remote buffer")) @@ -475,6 +495,10 @@ protected boolean checkNoRemainingTmpTables() return true; } + protected void addPrimaryKeyForMergeTarget(Session session, String tableName, String primaryKey) + { + } + protected class FailureRecoveryAssert { private final String query; @@ -485,6 +509,7 @@ protected class FailureRecoveryAssert private Optional setup = Optional.empty(); private Optional cleanup = Optional.empty(); private Set queryIds = new HashSet<>(); + private Optional primaryKey = Optional.empty(); public FailureRecoveryAssert(String query) { @@ -510,6 +535,12 @@ public FailureRecoveryAssert withCleanupQuery(Optional query) return this; } + public FailureRecoveryAssert withPrimaryKey(Optional primaryKey) + { + this.primaryKey = requireNonNull(primaryKey, "primaryKey is null"); + return this; + } + public FailureRecoveryAssert experiencing(InjectedFailureType failureType) { return experiencing(failureType, Optional.empty()); @@ -572,6 +603,8 @@ private ExecutionResult execute(Session session, String query, Optional String tableName = "table_" + randomNameSuffix(); setup.ifPresent(sql -> getQueryRunner().execute(noRetries(session), resolveTableName(sql, tableName))); + primaryKey.ifPresent(key -> addPrimaryKeyForMergeTarget(session, tableName, key)); + MaterializedResultWithPlan resultWithPlan = null; RuntimeException failure = null; String queryId = null;