Skip to content

Commit

Permalink
Refactor to fix bug that happens during parallel initialization of Re…
Browse files Browse the repository at this point in the history
…flectionCache. Fixes #47
  • Loading branch information
agoston committed Apr 24, 2021
1 parent 128ec3b commit c47d562
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<artifactId>spring-data-mongodb-encrypt</artifactId>
<packaging>jar</packaging>
<name>spring-data-mongodb-encrypt</name>
<version>2.6.1</version>
<version>2.6.2</version>
<description>High performance, per-field encryption for spring-data-mongodb</description>
<url>https://github.com/agoston/spring-data-mongodb-encrypt</url>

Expand Down
77 changes: 51 additions & 26 deletions src/main/java/com/bol/reflection/ReflectionCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Modifier;
Expand All @@ -17,21 +16,30 @@ public class ReflectionCache {

private static final Logger LOG = LoggerFactory.getLogger(ReflectionCache.class);

private Map<Class, List<Node>> reflectionCache = new ConcurrentHashMap<>();
private ConcurrentHashMap<Class, List<Node>> reflectionCache = new ConcurrentHashMap<>();

// used by CachedEncryptionEventListener to gather metadata of a class and all it fields, recursively.
public List<Node> reflectRecursive(Class objectClass) {
List<Node> result = reflectionCache.get(objectClass);
if (result != null) {
LOG.trace("cyclic reference found; {} is already mapped", objectClass.getName());
return result;
List<Node> nodes = reflectionCache.get(objectClass);
if (nodes != null) return nodes;

synchronized (this) {
return buildRecursive(objectClass, new HashMap<>());
}
}

// building is necessary to avoid putting half-processed data in `reflectionCache` (where it would be returned to other threads)
private List<Node> buildRecursive(Class objectClass, HashMap<Class, List<Node>> building) {
if (isPrimitive(objectClass)) return Collections.emptyList();

List<Node> processed = reflectionCache.get(objectClass);
if (processed != null) return processed;

// java primitive type; ignore
if (ClassUtils.isPrimitiveOrWrapper(objectClass)) return Collections.emptyList();
List<Node> processing = building.get(objectClass);
if (processing != null) return processing;

List<Node> nodes = new ArrayList<>();
reflectionCache.put(objectClass, nodes);
building.put(objectClass, nodes);

ReflectionUtils.doWithFields(objectClass, field -> {
String fieldName = field.getName();
Expand All @@ -49,16 +57,16 @@ public List<Node> reflectRecursive(Class objectClass) {
Type fieldGenericType = field.getGenericType();

if (Collection.class.isAssignableFrom(fieldType)) {
List<Node> children = processParameterizedTypes(fieldGenericType);
List<Node> children = processParameterizedTypes(fieldGenericType, building);
if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, unwrap(children), Node.Type.LIST, field));

} else if (Map.class.isAssignableFrom(fieldType)) {
List<Node> children = processParameterizedTypes(fieldGenericType);
List<Node> children = processParameterizedTypes(fieldGenericType, building);
if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, unwrap(children), Node.Type.MAP, field));

} else {
// descending into sub-documents
List<Node> children = reflectRecursive(fieldType);
List<Node> children = buildRecursive(fieldType, building);
if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, children, Node.Type.DOCUMENT, field));
}
}
Expand All @@ -68,23 +76,21 @@ public List<Node> reflectRecursive(Class objectClass) {
}
});

reflectionCache.put(objectClass, nodes);

return nodes;
}

// used by ReflectionEncryptionEventListener to map a single Document
// FIXME: this is a slimmed down copy-paste of reflectRecursive(); find a way to bring Cached and Reflective listener closer together!
public List<Node> reflectSingle(Class objectClass) {
List<Node> result = reflectionCache.get(objectClass);
if (result != null) {
LOG.trace("cyclic reference found; {} is already mapped", objectClass.getName());
return result;
}
return reflectionCache.computeIfAbsent(objectClass, this::buildSingle);
}

// java primitive type; ignore
if (objectClass.getPackage().getName().equals("java.lang")) return Collections.emptyList();
// FIXME: this is a slimmed down copy-paste of buildRecursive(); find a way to bring Cached and Reflective listener closer together!
private List<Node> buildSingle(Class objectClass) {
if (isPrimitive(objectClass)) return Collections.emptyList();

List<Node> nodes = new ArrayList<>();
reflectionCache.put(objectClass, nodes);

ReflectionUtils.doWithFields(objectClass, field -> {
String fieldName = field.getName();
Expand All @@ -99,7 +105,6 @@ public List<Node> reflectSingle(Class objectClass) {

} else {
Class<?> fieldType = field.getType();
Type fieldGenericType = field.getGenericType();

if (Collection.class.isAssignableFrom(fieldType)) {
nodes.add(new Node(fieldName, documentName, Collections.emptyList(), Node.Type.LIST, field));
Expand All @@ -120,21 +125,21 @@ public List<Node> reflectSingle(Class objectClass) {
return nodes;
}

List<Node> processParameterizedTypes(Type type) {
List<Node> processParameterizedTypes(Type type, HashMap<Class, List<Node>> building) {
if (type instanceof Class) {
List<Node> children = reflectRecursive((Class) type);
List<Node> children = buildRecursive((Class) type, building);
if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.DOCUMENT));

} else if (type instanceof ParameterizedType) {
ParameterizedType subType = (ParameterizedType) type;
Class rawType = (Class) subType.getRawType();

if (Collection.class.isAssignableFrom(rawType)) {
List<Node> children = processParameterizedTypes(subType.getActualTypeArguments()[0]);
List<Node> children = processParameterizedTypes(subType.getActualTypeArguments()[0], building);
if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.LIST));

} else if (Map.class.isAssignableFrom(rawType)) {
List<Node> children = processParameterizedTypes(subType.getActualTypeArguments()[1]);
List<Node> children = processParameterizedTypes(subType.getActualTypeArguments()[1], building);
if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.MAP));

} else {
Expand Down Expand Up @@ -171,4 +176,24 @@ static String parseFieldAnnotation(java.lang.reflect.Field field, String fieldNa
}
return fieldName;
}

// same as ClassUtils.isPrimitiveOrWrapper(), but also includes String
public static boolean isPrimitive(Class clazz) {
return clazz.isPrimitive() || primitiveClasses.contains(clazz);
}

private static Set<Class> primitiveClasses = new HashSet<>();

static {
primitiveClasses.add(Boolean.class);
primitiveClasses.add(Byte.class);
primitiveClasses.add(Character.class);
primitiveClasses.add(Double.class);
primitiveClasses.add(Float.class);
primitiveClasses.add(Integer.class);
primitiveClasses.add(Long.class);
primitiveClasses.add(Short.class);
primitiveClasses.add(Void.class);
primitiveClasses.add(String.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import java.util.Map;
import java.util.function.Function;

import static com.bol.reflection.Node.Type.*;
import static com.bol.reflection.Node.Type.DIRECT;
import static com.bol.reflection.ReflectionCache.isPrimitive;

// FIXME: check if we could bring CachedEncryptionEventListener and ReflectionEncryptionEventListener closer together; they are after all doing the same thing, just one at startup, one runtime
/**
* This is a reimplementation of {@link CachedEncryptionEventListener}, to support polymorphism.
* This means that while instead of walking by pre-cached class reflection, we have to walk by the Document provided and
Expand Down Expand Up @@ -62,8 +62,8 @@ void cryptDocument(Document document, Class clazz, Function<Object, Object> cryp
}

void diveInto(Object value, Type type, Function<Object, Object> crypt) {
// primitive type, nothing to do here
if (value.getClass().getPackage().getName().equals("java.lang")) return;
// java primitive type; ignore
if (isPrimitive(value.getClass())) return;

Class reflectiveClass = null;
Type[] typeArguments = null;
Expand Down
44 changes: 44 additions & 0 deletions src/test/java/com/bol/system/EncryptSystemTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,21 @@
import org.bson.types.ObjectId;
import org.junit.Before;
import org.junit.Test;
import org.junit.internal.Throwables;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.util.ReflectionTestUtils;

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import static com.bol.crypt.CryptVault.fromSignedByte;
import static com.bol.system.model.InitBean.*;
import static com.bol.system.model.MyBean.MONGO_NONSENSITIVEDATA;
import static com.bol.system.model.MyBean.MONGO_SECRETSTRING;
import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -38,6 +44,7 @@ public void cleanDb() {
mongoTemplate.dropCollection(Person.class);
mongoTemplate.dropCollection(RenamedField.class);
mongoTemplate.dropCollection(PrimitiveField.class);
mongoTemplate.dropCollection(InitBean.class);
}

@Test
Expand Down Expand Up @@ -650,6 +657,43 @@ public void checkMultipleEncryptVersion() {
}
}

// ReflectionCache is not initialized yet, and we hammer building it in parallel
@Test
public void checkParallelInitialization() {
int nThreads = Math.max(4, Runtime.getRuntime().availableProcessors());
ExecutorService executorService = Executors.newFixedThreadPool(nThreads);

ArrayList<Future<String>> futures = new ArrayList<>();

for (int i = 0; i < nThreads; i++) {
futures.add(
executorService.submit(() -> {
InitBean initBean = new InitBean();
initBean.addSubBean("my data 2");
initBean.data1 = "my data 1";
mongoTemplate.save(initBean);
return initBean.id;
})
);
}

futures.forEach(f -> {
try {
String id = f.get(10, TimeUnit.SECONDS);
assertThat(id).isNotNull();

Document fromMongo = mongoTemplate.getCollection(InitBean.MONGO_INITBEAN).find(new Document("_id", new ObjectId(id))).first();
Object data1 = fromMongo.get(MONGO_DATA1);
assertThat(data1).isInstanceOf(Binary.class);
List list = (List)fromMongo.get(MONGO_SUB_BEANS);
assertThat(list).hasSize(1);
Document subBean = (Document)list.get(0);
assertThat(subBean.get(MONGO_DATA2)).isInstanceOf(Binary.class);
} catch (Exception e) {
}
});
}

byte[] cryptedResultInDb(String value) {
MyBean bean = new MyBean();
bean.secretString = value;
Expand Down
36 changes: 36 additions & 0 deletions src/test/java/com/bol/system/model/InitBean.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.bol.system.model;

import com.bol.secure.Encrypted;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;

import java.util.ArrayList;
import java.util.List;

import static com.bol.system.model.InitBean.MONGO_INITBEAN;

@Document(collection = MONGO_INITBEAN)
public class InitBean {
public static final String MONGO_INITBEAN = "initbean";
public static final String MONGO_SUB_BEANS = "subBeans";
public static final String MONGO_DATA1 = "data1";
public static final String MONGO_DATA2 = "data2";

@Id
public String id;
@Encrypted
public String data1;
public List<InitSubBean> subBeans = new ArrayList<>();

// this also tests non-public subdocuments
public void addSubBean(String input) {
InitSubBean initSubBean = new InitSubBean();
initSubBean.data2 = input;
subBeans.add(initSubBean);
}
}

class InitSubBean {
@Encrypted
public String data2;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.bol.system.polymorphism;

import com.bol.crypt.CryptVault;
import com.bol.system.model.Person;
import com.bol.system.polymorphism.model.SubObject;
import com.bol.system.polymorphism.model.TestObject;
Expand All @@ -27,10 +26,7 @@
@SpringBootTest(classes = {ReflectionMongoDBConfiguration.class})
public class PolymorphismSystemTest {

@Autowired
MongoTemplate mongoTemplate;
@Autowired
CryptVault cryptVault;
@Autowired MongoTemplate mongoTemplate;

@Before
public void cleanDb() {
Expand Down

0 comments on commit c47d562

Please sign in to comment.