diff --git a/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java b/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java index db0c9785..098e010f 100644 --- a/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java +++ b/src/main/java/io/supertokens/storage/postgresql/QueryExecutorTemplate.java @@ -51,6 +51,18 @@ static int update(Start start, String QUERY, PreparedStatementValueSetter setter } } + static T update(Start start, String QUERY, PreparedStatementValueSetter setter, ResultSetValueExtractor mapper) + throws SQLException, StorageQueryException { + try (Connection con = ConnectionPool.getConnection(start)) { + try (PreparedStatement pst = con.prepareStatement(QUERY)) { + setter.setValues(pst); + try (ResultSet result = pst.executeQuery()) { + return mapper.extract(result); + } + } + } + } + static int update(Connection con, String QUERY, PreparedStatementValueSetter setter) throws SQLException, StorageQueryException { try (PreparedStatement pst = con.prepareStatement(QUERY)) { diff --git a/src/main/java/io/supertokens/storage/postgresql/Start.java b/src/main/java/io/supertokens/storage/postgresql/Start.java index ce0d27f2..d5a72796 100644 --- a/src/main/java/io/supertokens/storage/postgresql/Start.java +++ b/src/main/java/io/supertokens/storage/postgresql/Start.java @@ -3064,7 +3064,7 @@ public void addBulkImportUsers(AppIdentifier appIdentifier, List } @Override - public List getBulkImportUsers(AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BulkImportUserStatus status, + public List getBulkImportUsers(AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, @Nullable String bulkImportUserId, @Nullable Long createdAt) throws StorageQueryException { try { return BulkImportQueries.getBulkImportUsers(this, appIdentifier, limit, status, bulkImportUserId, createdAt); @@ -3074,7 +3074,7 @@ public List getBulkImportUsers(AppIdentifier appIdentifier, @Non } @Override - public void updateBulkImportUserStatus_Transaction(AppIdentifier appIdentifier, TransactionConnection con, @Nonnull String[] bulkImportUserIds, @Nonnull BulkImportUserStatus status) + public void updateBulkImportUserStatus_Transaction(AppIdentifier appIdentifier, TransactionConnection con, @Nonnull String[] bulkImportUserIds, @Nonnull BULK_IMPORT_USER_STATUS status) throws StorageQueryException { Connection sqlCon = (Connection) con.getConnection(); try { @@ -3083,4 +3083,13 @@ public void updateBulkImportUserStatus_Transaction(AppIdentifier appIdentifier, throw new StorageQueryException(e); } } + + @Override + public List deleteBulkImportUsers(AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds) throws StorageQueryException { + try { + return BulkImportQueries.deleteBulkImportUsers(this, appIdentifier, bulkImportUserIds); + } catch (SQLException e) { + throw new StorageQueryException(e); + } + } } diff --git a/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java b/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java index 4c77b1a1..2f3b5139 100644 --- a/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java +++ b/src/main/java/io/supertokens/storage/postgresql/queries/BulkImportQueries.java @@ -29,7 +29,7 @@ import javax.annotation.Nullable; import io.supertokens.pluginInterface.RowMapper; -import io.supertokens.pluginInterface.bulkimport.BulkImportStorage.BulkImportUserStatus; +import io.supertokens.pluginInterface.bulkimport.BulkImportStorage.BULK_IMPORT_USER_STATUS; import io.supertokens.pluginInterface.bulkimport.BulkImportUser; import io.supertokens.pluginInterface.exceptions.StorageQueryException; import io.supertokens.pluginInterface.multitenancy.AppIdentifier; @@ -87,12 +87,12 @@ public static void insertBulkImportUsers(Start start, AppIdentifier appIdentifie for (BulkImportUser user : users) { pst.setString(parameterIndex++, user.id); pst.setString(parameterIndex++, appIdentifier.getAppId()); - pst.setString(parameterIndex++, user.toString()); + pst.setString(parameterIndex++, user.toRawDataForDbStorage()); } }); } - public static void updateBulkImportUserStatus_Transaction(Start start, Connection con, AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds, @Nonnull BulkImportUserStatus status) + public static void updateBulkImportUserStatus_Transaction(Start start, Connection con, AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds, @Nonnull BULK_IMPORT_USER_STATUS status) throws SQLException, StorageQueryException { if (bulkImportUserIds.length == 0) { return; @@ -125,7 +125,7 @@ public static void updateBulkImportUserStatus_Transaction(Start start, Connectio }); } - public static List getBulkImportUsers(Start start, AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BulkImportUserStatus status, + public static List getBulkImportUsers(Start start, AppIdentifier appIdentifier, @Nonnull Integer limit, @Nullable BULK_IMPORT_USER_STATUS status, @Nullable String bulkImportUserId, @Nullable Long createdAt) throws SQLException, StorageQueryException { @@ -168,6 +168,43 @@ public static List getBulkImportUsers(Start start, AppIdentifier }); } + public static List deleteBulkImportUsers(Start start, AppIdentifier appIdentifier, @Nonnull String[] bulkImportUserIds) throws SQLException, StorageQueryException { + if (bulkImportUserIds.length == 0) { + return new ArrayList<>(); + } + + String baseQuery = "DELETE FROM " + Config.getConfig(start).getBulkImportUsersTable(); + StringBuilder queryBuilder = new StringBuilder(baseQuery); + + List parameters = new ArrayList<>(); + + queryBuilder.append(" WHERE app_id = ?"); + parameters.add(appIdentifier.getAppId()); + + queryBuilder.append(" AND id IN ("); + for (int i = 0; i < bulkImportUserIds.length; i++) { + if (i != 0) { + queryBuilder.append(", "); + } + queryBuilder.append("?"); + parameters.add(bulkImportUserIds[i]); + } + queryBuilder.append(") RETURNING id"); + + String query = queryBuilder.toString(); + + return update(start, query, pst -> { + for (int i = 0; i < parameters.size(); i++) { + pst.setObject(i + 1, parameters.get(i)); + } + }, result -> { + List deletedIds = new ArrayList<>(); + while (result.next()) { + deletedIds.add(result.getString("id")); + } + return deletedIds; + }); + } private static class BulkImportUserRowMapper implements RowMapper { private static final BulkImportUserRowMapper INSTANCE = new BulkImportUserRowMapper(); @@ -180,8 +217,8 @@ private static BulkImportUserRowMapper getInstance() { @Override public BulkImportUser map(ResultSet result) throws Exception { - return BulkImportUser.fromDBJson(result.getString("id"), result.getString("raw_data"), - BulkImportUserStatus.valueOf(result.getString("status")), + return BulkImportUser.fromRawDataFromDbStorage(result.getString("id"), result.getString("raw_data"), + BULK_IMPORT_USER_STATUS.valueOf(result.getString("status")), result.getLong("created_at"), result.getLong("updated_at")); } }