Skip to content

Commit

Permalink
Add CrudMethodMetadata to support ReadPreference annotations on overr…
Browse files Browse the repository at this point in the history
…idden base repository methods.

See: #2971
Original Pull Request: #4503
  • Loading branch information
mp911de authored and christophstrobl committed Oct 12, 2023
1 parent 74b07e5 commit 5d25645
Show file tree
Hide file tree
Showing 11 changed files with 632 additions and 150 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.support;

import java.lang.reflect.Method;
import java.util.Optional;

import com.mongodb.ReadPreference;

/**
* Interface to abstract {@link CrudMethodMetadata} that provide the {@link ReadPreference} to be used for query
* execution.
*
* @author Mark Paluch
* @since 4.2
*/
public interface CrudMethodMetadata {

/**
* Returns the {@link ReadPreference} to be used.
*
* @return the {@link ReadPreference} to be used.
*/
Optional<ReadPreference> getReadPreference();

/**
* Returns the {@link Method} that this metadata applies to.
*
* @return the {@link Method} that this metadata applies to.
*/
Method getMethod();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.repository.support;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
import org.springframework.aop.TargetSource;
import org.springframework.aop.framework.ProxyFactory;
import org.springframework.beans.factory.BeanClassLoaderAware;
import org.springframework.core.NamedThreadLocal;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.data.repository.core.RepositoryInformation;
import org.springframework.data.repository.core.support.RepositoryProxyPostProcessor;
import org.springframework.lang.Nullable;
import org.springframework.transaction.support.TransactionSynchronizationManager;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

import com.mongodb.ReadPreference;

/**
* {@link RepositoryProxyPostProcessor} that sets up interceptors to read metadata information from the invoked method.
* This is necessary to allow redeclaration of CRUD methods in repository interfaces and configure read preference
* information or query hints on them.
*
* @author Mark Paluch
*/
class CrudMethodMetadataPostProcessor implements RepositoryProxyPostProcessor, BeanClassLoaderAware {

private @Nullable ClassLoader classLoader = ClassUtils.getDefaultClassLoader();

@Override
public void setBeanClassLoader(ClassLoader classLoader) {
this.classLoader = classLoader;
}

@Override
public void postProcess(ProxyFactory factory, RepositoryInformation repositoryInformation) {
factory.addAdvice(new CrudMethodMetadataPopulatingMethodInterceptor(repositoryInformation));
}

/**
* Returns a {@link CrudMethodMetadata} proxy that will lookup the actual target object by obtaining a thread bound
* instance from the {@link TransactionSynchronizationManager} later.
*/
CrudMethodMetadata getCrudMethodMetadata() {

ProxyFactory factory = new ProxyFactory();

factory.addInterface(CrudMethodMetadata.class);
factory.setTargetSource(new ThreadBoundTargetSource());

return (CrudMethodMetadata) factory.getProxy(this.classLoader);
}

/**
* {@link MethodInterceptor} to build and cache {@link DefaultCrudMethodMetadata} instances for the invoked methods.
* Will bind the found information to a {@link TransactionSynchronizationManager} for later lookup.
*
* @see DefaultCrudMethodMetadata
*/
static class CrudMethodMetadataPopulatingMethodInterceptor implements MethodInterceptor {

private static final ThreadLocal<MethodInvocation> currentInvocation = new NamedThreadLocal<>(
"Current AOP method invocation");

private final ConcurrentMap<Method, CrudMethodMetadata> metadataCache = new ConcurrentHashMap<>();
private final Set<Method> implementations = new HashSet<>();

CrudMethodMetadataPopulatingMethodInterceptor(RepositoryInformation repositoryInformation) {

ReflectionUtils.doWithMethods(repositoryInformation.getRepositoryInterface(), implementations::add,
method -> !repositoryInformation.isQueryMethod(method));
}

/**
* Return the AOP Alliance {@link MethodInvocation} object associated with the current invocation.
*
* @return the invocation object associated with the current invocation.
* @throws IllegalStateException if there is no AOP invocation in progress, or if the
* {@link CrudMethodMetadataPopulatingMethodInterceptor} was not added to this interceptor chain.
*/
static MethodInvocation currentInvocation() throws IllegalStateException {

MethodInvocation mi = currentInvocation.get();

if (mi == null)
throw new IllegalStateException(
"No MethodInvocation found: Check that an AOP invocation is in progress, and that the "
+ "CrudMethodMetadataPopulatingMethodInterceptor is upfront in the interceptor chain.");
return mi;
}

@Override
public Object invoke(MethodInvocation invocation) throws Throwable {

Method method = invocation.getMethod();

if (!implementations.contains(method)) {
return invocation.proceed();
}

MethodInvocation oldInvocation = currentInvocation.get();
currentInvocation.set(invocation);

try {

CrudMethodMetadata metadata = (CrudMethodMetadata) TransactionSynchronizationManager.getResource(method);

if (metadata != null) {
return invocation.proceed();
}

CrudMethodMetadata methodMetadata = metadataCache.get(method);

if (methodMetadata == null) {

methodMetadata = new DefaultCrudMethodMetadata(method);
CrudMethodMetadata tmp = metadataCache.putIfAbsent(method, methodMetadata);

if (tmp != null) {
methodMetadata = tmp;
}
}

TransactionSynchronizationManager.bindResource(method, methodMetadata);

try {
return invocation.proceed();
} finally {
TransactionSynchronizationManager.unbindResource(method);
}
} finally {
currentInvocation.set(oldInvocation);
}
}
}

/**
* Default implementation of {@link CrudMethodMetadata} that will inspect the backing method for annotations.
*/
static class DefaultCrudMethodMetadata implements CrudMethodMetadata {

private final Optional<ReadPreference> readPreference;
private final Method method;

/**
* Creates a new {@link DefaultCrudMethodMetadata} for the given {@link Method}.
*
* @param method must not be {@literal null}.
*/
DefaultCrudMethodMetadata(Method method) {

Assert.notNull(method, "Method must not be null");

this.readPreference = findReadPreference(method);
this.method = method;
}

private Optional<ReadPreference> findReadPreference(Method method) {

org.springframework.data.mongodb.repository.ReadPreference preference = AnnotatedElementUtils
.findMergedAnnotation(method, org.springframework.data.mongodb.repository.ReadPreference.class);

if (preference == null) {

preference = AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(),
org.springframework.data.mongodb.repository.ReadPreference.class);
}

if (preference == null) {
return Optional.empty();
}

return Optional.of(com.mongodb.ReadPreference.valueOf(preference.value()));

}

@Override
public Optional<ReadPreference> getReadPreference() {
return readPreference;
}

@Override
public Method getMethod() {
return method;
}
}

private static class ThreadBoundTargetSource implements TargetSource {

@Override
public Class<?> getTargetClass() {
return CrudMethodMetadata.class;
}

@Override
public boolean isStatic() {
return false;
}

@Override
public Object getTarget() {

MethodInvocation invocation = CrudMethodMetadataPopulatingMethodInterceptor.currentInvocation();
return TransactionSynchronizationManager.getResource(invocation.getMethod());
}

@Override
public void releaseTarget(Object target) {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public String getCollectionName() {
}

public String getIdAttribute() {
return entityMetadata.getRequiredIdProperty().getName();
return entityMetadata.hasIdProperty() ? entityMetadata.getRequiredIdProperty().getName() : "_id";
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public class MongoRepositoryFactory extends RepositoryFactorySupport {

private static final SpelExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

private final CrudMethodMetadataPostProcessor crudMethodMetadataPostProcessor = new CrudMethodMetadataPostProcessor();
private final MongoOperations operations;
private final MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;

Expand All @@ -75,6 +76,15 @@ public MongoRepositoryFactory(MongoOperations mongoOperations) {

this.operations = mongoOperations;
this.mappingContext = mongoOperations.getConverter().getMappingContext();

addRepositoryProxyPostProcessor(crudMethodMetadataPostProcessor);
}

@Override
public void setBeanClassLoader(ClassLoader classLoader) {

super.setBeanClassLoader(classLoader);
crudMethodMetadataPostProcessor.setBeanClassLoader(classLoader);
}

@Override
Expand Down Expand Up @@ -127,7 +137,13 @@ protected Object getTargetRepository(RepositoryInformation information) {

MongoEntityInformation<?, Serializable> entityInformation = getEntityInformation(information.getDomainType(),
information);
return getTargetRepositoryViaReflection(information, information, entityInformation, operations);
Object targetRepository = getTargetRepositoryViaReflection(information, entityInformation, operations);

if (targetRepository instanceof SimpleMongoRepository<?, ?> repository) {
repository.setRepositoryMethodMetadata(crudMethodMetadataPostProcessor.getCrudMethodMetadata());
}

return targetRepository;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public class ReactiveMongoRepositoryFactory extends ReactiveRepositoryFactorySup

private static final SpelExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

private final CrudMethodMetadataPostProcessor crudMethodMetadataPostProcessor = new CrudMethodMetadataPostProcessor();
private final ReactiveMongoOperations operations;
private final MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;

Expand All @@ -76,7 +77,16 @@ public ReactiveMongoRepositoryFactory(ReactiveMongoOperations mongoOperations) {

this.operations = mongoOperations;
this.mappingContext = mongoOperations.getConverter().getMappingContext();

setEvaluationContextProvider(ReactiveQueryMethodEvaluationContextProvider.DEFAULT);
addRepositoryProxyPostProcessor(crudMethodMetadataPostProcessor);
}

@Override
public void setBeanClassLoader(ClassLoader classLoader) {

super.setBeanClassLoader(classLoader);
crudMethodMetadataPostProcessor.setBeanClassLoader(classLoader);
}

@Override
Expand Down Expand Up @@ -114,7 +124,13 @@ protected Object getTargetRepository(RepositoryInformation information) {

MongoEntityInformation<?, Serializable> entityInformation = getEntityInformation(information.getDomainType(),
information);
return getTargetRepositoryViaReflection(information, information, entityInformation, operations);
Object targetRepository = getTargetRepositoryViaReflection(information, entityInformation, operations);

if (targetRepository instanceof SimpleReactiveMongoRepository<?, ?> repository) {
repository.setRepositoryMethodMetadata(crudMethodMetadataPostProcessor.getCrudMethodMetadata());
}

return targetRepository;
}

@Override
Expand Down
Loading

0 comments on commit 5d25645

Please sign in to comment.