Skip to content

Commit

Permalink
Implement different db users to seperate the tenants from each other
Browse files Browse the repository at this point in the history
  • Loading branch information
RandomTannenbaum committed Nov 13, 2024
1 parent 3cee603 commit b9be875
Show file tree
Hide file tree
Showing 18 changed files with 780 additions and 150 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import jakarta.persistence.EntityNotFoundException;
import org.flywaydb.core.Flyway;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

Expand All @@ -10,6 +12,8 @@ public class FlywayMultitenantMigrationInitializer {
private final TenantConfigProviderInterface tenantConfigProvider;
private final String[] scriptLocations;

private static final Logger logger = LoggerFactory.getLogger(FlywayMultitenantMigrationInitializer.class);

public FlywayMultitenantMigrationInitializer(TenantConfigProviderInterface tenantConfigProvider,
final @Value("${spring.flyway.locations}") String[] scriptLocations) {
this.tenantConfigProvider = tenantConfigProvider;
Expand All @@ -20,9 +24,11 @@ public void migrateFlyway() {
this.tenantConfigProvider.getTenantConfigs().forEach((tenantConfig) -> {
TenantConfigProvider.DataSourceConfig dataSourceConfig = this.tenantConfigProvider
.getTenantConfigById(tenantConfig.tenantId())
.map(TenantConfigProvider.TenantConfig::dataSourceConfig).orElseThrow(
.map(TenantConfigProvider.TenantConfig::dataSourceConfigFlyway).orElseThrow(
() -> new EntityNotFoundException("Cannot find tenant for configuring flyway migration"));

logUsedHibernateConfig(dataSourceConfig);

Flyway tenantSchemaFlyway = Flyway.configure() //
.dataSource(dataSourceConfig.url(), dataSourceConfig.name(), dataSourceConfig.password()) //
.locations(scriptLocations) //
Expand All @@ -32,6 +38,9 @@ public void migrateFlyway() {

tenantSchemaFlyway.migrate();
});
}

private void logUsedHibernateConfig(TenantConfigProvider.DataSourceConfig dataSourceConfig) {
logger.error("use DbConfig: user={}", dataSourceConfig.name());
}
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
package ch.puzzle.okr.multitenancy;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.ConfigurableEnvironment;

import java.util.Properties;

/**
* Reads the (not tenant specific) hibernate configuration form the "hibernate.x" properties in the
* applicationX.properties file. It then caches the configuration as DbConfig object. The data from the DbConfig object
* is used by the SchemaMultiTenantConnectionProvider via getHibernateConfig() and getHibernateConfig(tenantId).
*
* <pre>
* getHibernateConfig() returns the cached DbConfig as properties.
* </pre>
*
* <pre>
* getHibernateConfig(tenantId) patches the DbConfig data with tenant specific data (from
* TenantConfigProvider) and returns the patched data as properties
* </pre>
*/
public class HibernateContext {
public static final String HIBERNATE_CONNECTION_URL = "hibernate.connection.url";
public static final String HIBERNATE_CONNECTION_USERNAME = "hibernate.connection.username";
Expand All @@ -14,6 +30,8 @@ public class HibernateContext {
public static String SPRING_DATASOURCE_USERNAME = "spring.datasource.username";
public static String SPRING_DATASOURCE_PASSWORD = "spring.datasource.password";

private static final Logger logger = LoggerFactory.getLogger(HibernateContext.class);

public record DbConfig(String url, String username, String password, String multiTenancy) {

public boolean isValid() {
Expand All @@ -29,20 +47,22 @@ private boolean hasEmptyValues() {
}
}

// general (not tenant specific) hibernate config
private static DbConfig cachedHibernateConfig;

public static void extractAndSetHibernateConfig(ConfigurableEnvironment environment) {
DbConfig dbConfig = extractHibernateConfig(environment);
setHibernateConfig(dbConfig);
logUsedHibernateConfig(dbConfig);
}

public static void setHibernateConfig(DbConfig dbConfig) {
if (dbConfig == null || !dbConfig.isValid()) {
throw new RuntimeException("Invalid hibernate configuration " + dbConfig);
}
cachedHibernateConfig = dbConfig;
}

public static void extractAndSetHibernateConfig(ConfigurableEnvironment environment) {
DbConfig dbConfig = extractHibernateConfig(environment);
HibernateContext.setHibernateConfig(dbConfig);
}

private static DbConfig extractHibernateConfig(ConfigurableEnvironment environment) {
String url = environment.getProperty(HibernateContext.HIBERNATE_CONNECTION_URL);
String username = environment.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME);
Expand All @@ -60,7 +80,9 @@ public static Properties getHibernateConfig() {
if (cachedHibernateConfig == null) {
throw new RuntimeException("No cached hibernate configuration found");
}
return getConfigAsProperties(cachedHibernateConfig);
var config = getConfigAsProperties(cachedHibernateConfig);
logUsedHibernateConfig(config);
return config;
}

private static Properties getConfigAsProperties(DbConfig dbConfig) {
Expand All @@ -74,4 +96,48 @@ private static Properties getConfigAsProperties(DbConfig dbConfig) {
properties.put(HibernateContext.SPRING_DATASOURCE_PASSWORD, dbConfig.password());
return properties;
}

public static Properties getHibernateConfig(String tenantIdentifier) {
if (cachedHibernateConfig == null) {
throw new RuntimeException("No cached hibernate configuration found (for tenant " + tenantIdentifier + ")");
}
var config = getConfigAsPropertiesAndPatch(cachedHibernateConfig, tenantIdentifier);
logUsedHibernateConfig(tenantIdentifier, config);
return config;
}

private static Properties getConfigAsPropertiesAndPatch(DbConfig dbConfig, String tenantIdentifier) {
Properties properties = getConfigAsProperties(dbConfig);
return patchConfigAppForTenant(properties, tenantIdentifier);
}

private static Properties patchConfigAppForTenant(Properties properties, String tenantIdentifier) {
TenantConfigProvider.TenantConfig cachedTenantConfig = TenantConfigProvider
.getCachedTenantConfig(tenantIdentifier);
if (cachedTenantConfig == null) {
throw new RuntimeException("No cached tenant configuration found (for tenant " + tenantIdentifier + ")");
}

TenantConfigProvider.DataSourceConfig dataSourceConfigApp = cachedTenantConfig.dataSourceConfigApp();
properties.put(HibernateContext.HIBERNATE_CONNECTION_USERNAME, dataSourceConfigApp.name());
properties.put(HibernateContext.HIBERNATE_CONNECTION_PASSWORD, dataSourceConfigApp.password());
properties.put(HibernateContext.SPRING_DATASOURCE_USERNAME, dataSourceConfigApp.name());
properties.put(HibernateContext.SPRING_DATASOURCE_PASSWORD, dataSourceConfigApp.password());
return properties;
}

private static void logUsedHibernateConfig(DbConfig hibernateConfig) {
logger.error("set DbConfig: user={}", hibernateConfig.username());
}

private static void logUsedHibernateConfig(Properties hibernateConfig) {
logger.error("use DbConfig: user={}",
hibernateConfig.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME)); //
}

private static void logUsedHibernateConfig(String tenantId, Properties hibernateConfig) {
logger.error("use DbConfig: tenant={} user={}", tenantId,
hibernateConfig.getProperty(HibernateContext.HIBERNATE_CONNECTION_USERNAME));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,35 @@

import static ch.puzzle.okr.multitenancy.TenantContext.DEFAULT_TENANT_ID;

/**
* The central piece of code of multitenancy.
*
* <pre>
* getConnection(tenantId) sets in each tenant request the specific db schema for the
* tenant. This guarantees that each tenant always works in its own DB schema.
*
* getConnection(tenantId) -> Connection calls in the abstract super class the
* getConnection(tenantId) -> Connection which calls the abstract
* selectConnectionProvider(tenantIdentifier) -> ConnectionProvider which is implemented
* in SchemaMultiTenantConnectionProvider.
* </pre>
*
* <pre>
* Some coding details:
*
* selectConnectionProvider(tenantId) -> ConnectionProvider returns for a tenant a
* ConnectionProvider. It first checks if the ConnectionProvider for the tenant is already
* cached (in connectionProviderMap). If the ConnectionProvider is cached, it returns it.
* Otherwise it creates a ConnectionProvider for the tenant, cache it and return it.
*
* To create a ConnectionProvider for the tenant, it tries to load the configuration from
* the hibernate properties. For this it uses 2 methods of HibernateContext:
* getHibernateConfig() if the tenant is the DEFAULT_TENANT_ID (public) and
* getHibernateConfig(tenantId) for all other tenants. With this information its then
* possible to create and cache a ConnectionProvider for the tenant. If no matching
* hibernate properties are found, then an exception is thrown.
* </pre>
*/
public class SchemaMultiTenantConnectionProvider extends AbstractMultiTenantConnectionProvider<String> {

private static final Logger logger = LoggerFactory.getLogger(SchemaMultiTenantConnectionProvider.class);
Expand All @@ -31,15 +60,15 @@ public Connection getConnection(String tenantIdentifier) throws SQLException {
return getConnection(tenantIdentifier, connection);
}

protected Connection getConnection(String tenantIdentifier, Connection connection) throws SQLException {
Connection getConnection(String tenantIdentifier, Connection connection) throws SQLException {
String schema = convertTenantIdToSchemaName(tenantIdentifier);
logger.debug("Setting schema to {}", schema);

connection.createStatement().execute(String.format("SET SCHEMA '%s';", schema));
return connection;
}

private String convertTenantIdToSchemaName(String tenantIdentifier) {
String convertTenantIdToSchemaName(String tenantIdentifier) {
return Objects.equals(tenantIdentifier, DEFAULT_TENANT_ID) ? tenantIdentifier
: MessageFormat.format("okr_{0}", tenantIdentifier);
}
Expand All @@ -54,13 +83,13 @@ protected ConnectionProvider selectConnectionProvider(String tenantIdentifier) {
return getConnectionProvider(tenantIdentifier);
}

protected ConnectionProvider getConnectionProvider(String tenantIdentifier) {
ConnectionProvider getConnectionProvider(String tenantIdentifier) {
return Optional.ofNullable(tenantIdentifier) //
.map(connectionProviderMap::get) //
.orElseGet(() -> createNewConnectionProvider(tenantIdentifier));
.orElseGet(() -> createAndCacheNewConnectionProvider(tenantIdentifier));
}

private ConnectionProvider createNewConnectionProvider(String tenantIdentifier) {
private ConnectionProvider createAndCacheNewConnectionProvider(String tenantIdentifier) {
return Optional.ofNullable(tenantIdentifier) //
.map(this::createConnectionProvider) //
.map(connectionProvider -> {
Expand All @@ -78,29 +107,25 @@ private ConnectionProvider createConnectionProvider(String tenantIdentifier) {
.orElse(null);
}

protected Properties getHibernatePropertiesForTenantIdentifier(String tenantIdentifier) {
Properties properties = getHibernateProperties();
if (properties == null || properties.isEmpty()) {
throw new RuntimeException("Cannot load hibernate properties from application.properties)");
Properties getHibernatePropertiesForTenantIdentifier(String tenantIdentifier) {
Properties properties = getHibernateProperties(tenantIdentifier);
if (properties.isEmpty()) {
throw new RuntimeException("Cannot load hibernate properties from application.properties");
}
if (!Objects.equals(tenantIdentifier, DEFAULT_TENANT_ID)) {
properties.put(AvailableSettings.DEFAULT_SCHEMA, MessageFormat.format("okr_{0}", tenantIdentifier));
}
return properties;
}

private ConnectionProvider initConnectionProvider(Properties hibernateProperties) {
ConnectionProvider initConnectionProvider(Properties hibernateProperties) {
Map<String, Object> configProperties = convertPropertiesToMap(hibernateProperties);
DriverManagerConnectionProviderImpl connectionProvider = getDriverManagerConnectionProviderImpl();
DriverManagerConnectionProviderImpl connectionProvider = new DriverManagerConnectionProviderImpl();
connectionProvider.configure(configProperties);
return connectionProvider;
}

protected DriverManagerConnectionProviderImpl getDriverManagerConnectionProviderImpl() {
return new DriverManagerConnectionProviderImpl();
}

private Map<String, Object> convertPropertiesToMap(Properties properties) {
Map<String, Object> convertPropertiesToMap(Properties properties) {
Map<String, Object> configProperties = new HashMap<>();
for (String key : properties.stringPropertyNames()) {
String value = properties.getProperty(key);
Expand All @@ -109,7 +134,10 @@ private Map<String, Object> convertPropertiesToMap(Properties properties) {
return configProperties;
}

protected Properties getHibernateProperties() {
return HibernateContext.getHibernateConfig();
private Properties getHibernateProperties(String tenantIdentifier) {
if (tenantIdentifier.equals(DEFAULT_TENANT_ID)) {
return HibernateContext.getHibernateConfig();
}
return HibernateContext.getHibernateConfig(tenantIdentifier);
}
}
Original file line number Diff line number Diff line change
@@ -1,27 +1,62 @@
package ch.puzzle.okr.multitenancy;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;

import java.text.MessageFormat;
import java.util.*;

/**
* Reads the configuration of the tenants (as TenantConfig objects) from the applicationX.properties and caches each
* TenantConfig in the TenantConfigs class.
*/
@Component
public class TenantConfigProvider implements TenantConfigProviderInterface {
private static final String EMAIL_DELIMITER = ",";
private final Map<String, TenantConfig> tenantConfigs = new HashMap<>();
private final Environment env;

private enum DbType {
bootstrap, app, fly
}

private static final Logger logger = LoggerFactory.getLogger(TenantConfigProvider.class);

public TenantConfigProvider(final @Value("${okr.tenant-ids}") String[] tenantIds, Environment env) {
this.env = env;
for (String tenantId : tenantIds) {
OauthConfig c = readOauthConfig(tenantId);
tenantConfigs.put(tenantId,
createTenantConfig(c.jwkSetUri(), c.frontendClientIssuerUrl(), c.frontendClientId(), tenantId));
TenantConfig tenantConfig = createTenantConfig(c.jwkSetUri(), c.frontendClientIssuerUrl(),
c.frontendClientId(), tenantId);

tenantConfigs.put(tenantId, tenantConfig);
cacheTenantConfig(tenantId, tenantConfig); // cache tenantConfig for Hibernate connections
}
}

private void cacheTenantConfig(String tenantId, TenantConfig tenantConfig) {
TenantConfigs.add(tenantId, tenantConfig);
logCachingTenantConfig(tenantId, tenantConfig);
}

private void logCachingTenantConfig(String tenantId, TenantConfig tenantConfig) {
logger.error("cache TenantConfig: tenantId={}, users={}", //
tenantId, //
tenantConfig.dataSourceConfigFlyway().name() + " | " + tenantConfig.dataSourceConfigApp().name());
}

public static TenantConfigProvider.TenantConfig getCachedTenantConfig(String tenantId) {
return TenantConfigs.get(tenantId);
}

// for tests
public static void clearTenantConfigsCache() {
TenantConfigs.clear();
}

private OauthConfig readOauthConfig(String tenantId) {
return new OauthConfig(
env.getProperty(MessageFormat.format("okr.tenants.{0}.security.oauth2.resourceserver.jwt.jwk-set-uri",
Expand All @@ -32,8 +67,11 @@ private OauthConfig readOauthConfig(String tenantId) {

private TenantConfig createTenantConfig(String jwkSetUriTemplate, String frontendClientIssuerUrl,
String frontendClientId, String tenantId) {
return new TenantConfig(tenantId, getOkrChampionEmailsFromTenant(tenantId), jwkSetUriTemplate,
frontendClientIssuerUrl, frontendClientId, this.readDataSourceConfig(tenantId));

return new TenantConfig(tenantId, getOkrChampionEmailsFromTenant(tenantId), jwkSetUriTemplate, //
frontendClientIssuerUrl, frontendClientId, //
this.readDataSourceConfigFlyway(tenantId), //
this.readDataSourceConfigApp(tenantId));
}

private String[] getOkrChampionEmailsFromTenant(String tenantId) {
Expand All @@ -45,11 +83,19 @@ public List<TenantConfig> getTenantConfigs() {
return this.tenantConfigs.values().stream().toList();
}

private DataSourceConfig readDataSourceConfig(String tenantId) {
private DataSourceConfig readDataSourceConfigFlyway(String tenantId) {
return readDataSourceConfig(tenantId, DbType.fly);
}

private DataSourceConfig readDataSourceConfigApp(String tenantId) {
return readDataSourceConfig(tenantId, DbType.app);
}

private DataSourceConfig readDataSourceConfig(String tenantId, DbType dbType) {
return new DataSourceConfig(env.getProperty("okr.datasource.driver-class-name"),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.url", tenantId)),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.username", tenantId)),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.password", tenantId)),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.username." + dbType, tenantId)),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.password." + dbType, tenantId)),
env.getProperty(MessageFormat.format("okr.tenants.{0}.datasource.schema", tenantId)));
}

Expand All @@ -62,7 +108,7 @@ public Optional<String> getJwkSetUri(String tenantId) {
}

public record TenantConfig(String tenantId, String[] okrChampionEmails, String jwkSetUri, String issuerUrl,
String clientId, DataSourceConfig dataSourceConfig) {
String clientId, DataSourceConfig dataSourceConfigFlyway, DataSourceConfig dataSourceConfigApp) {
}

public record DataSourceConfig(String driverClassName, String url, String name, String password, String schema) {
Expand Down
Loading

0 comments on commit b9be875

Please sign in to comment.