From c47d562a751e4a02a2f227fa1aa51ab001b6bf7f Mon Sep 17 00:00:00 2001 From: Agoston Horvath Date: Sat, 24 Apr 2021 15:17:06 +0200 Subject: [PATCH] Refactor to fix bug that happens during parallel initialization of ReflectionCache. Fixes #47 --- pom.xml | 2 +- .../com/bol/reflection/ReflectionCache.java | 77 ++++++++++++------- .../ReflectionEncryptionEventListener.java | 8 +- .../com/bol/system/EncryptSystemTest.java | 44 +++++++++++ .../java/com/bol/system/model/InitBean.java | 36 +++++++++ .../polymorphism/PolymorphismSystemTest.java | 6 +- 6 files changed, 137 insertions(+), 36 deletions(-) create mode 100644 src/test/java/com/bol/system/model/InitBean.java diff --git a/pom.xml b/pom.xml index 84ea665..c3920fd 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ spring-data-mongodb-encrypt jar spring-data-mongodb-encrypt - 2.6.1 + 2.6.2 High performance, per-field encryption for spring-data-mongodb https://github.com/agoston/spring-data-mongodb-encrypt diff --git a/src/main/java/com/bol/reflection/ReflectionCache.java b/src/main/java/com/bol/reflection/ReflectionCache.java index fb56789..05847c5 100644 --- a/src/main/java/com/bol/reflection/ReflectionCache.java +++ b/src/main/java/com/bol/reflection/ReflectionCache.java @@ -4,7 +4,6 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.data.mongodb.core.mapping.Field; -import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; import java.lang.reflect.Modifier; @@ -17,21 +16,30 @@ public class ReflectionCache { private static final Logger LOG = LoggerFactory.getLogger(ReflectionCache.class); - private Map> reflectionCache = new ConcurrentHashMap<>(); + private ConcurrentHashMap> reflectionCache = new ConcurrentHashMap<>(); // used by CachedEncryptionEventListener to gather metadata of a class and all it fields, recursively. public List reflectRecursive(Class objectClass) { - List result = reflectionCache.get(objectClass); - if (result != null) { - LOG.trace("cyclic reference found; {} is already mapped", objectClass.getName()); - return result; + List nodes = reflectionCache.get(objectClass); + if (nodes != null) return nodes; + + synchronized (this) { + return buildRecursive(objectClass, new HashMap<>()); } + } + + // building is necessary to avoid putting half-processed data in `reflectionCache` (where it would be returned to other threads) + private List buildRecursive(Class objectClass, HashMap> building) { + if (isPrimitive(objectClass)) return Collections.emptyList(); + + List processed = reflectionCache.get(objectClass); + if (processed != null) return processed; - // java primitive type; ignore - if (ClassUtils.isPrimitiveOrWrapper(objectClass)) return Collections.emptyList(); + List processing = building.get(objectClass); + if (processing != null) return processing; List nodes = new ArrayList<>(); - reflectionCache.put(objectClass, nodes); + building.put(objectClass, nodes); ReflectionUtils.doWithFields(objectClass, field -> { String fieldName = field.getName(); @@ -49,16 +57,16 @@ public List reflectRecursive(Class objectClass) { Type fieldGenericType = field.getGenericType(); if (Collection.class.isAssignableFrom(fieldType)) { - List children = processParameterizedTypes(fieldGenericType); + List children = processParameterizedTypes(fieldGenericType, building); if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, unwrap(children), Node.Type.LIST, field)); } else if (Map.class.isAssignableFrom(fieldType)) { - List children = processParameterizedTypes(fieldGenericType); + List children = processParameterizedTypes(fieldGenericType, building); if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, unwrap(children), Node.Type.MAP, field)); } else { // descending into sub-documents - List children = reflectRecursive(fieldType); + List children = buildRecursive(fieldType, building); if (!children.isEmpty()) nodes.add(new Node(fieldName, documentName, children, Node.Type.DOCUMENT, field)); } } @@ -68,23 +76,21 @@ public List reflectRecursive(Class objectClass) { } }); + reflectionCache.put(objectClass, nodes); + return nodes; } // used by ReflectionEncryptionEventListener to map a single Document - // FIXME: this is a slimmed down copy-paste of reflectRecursive(); find a way to bring Cached and Reflective listener closer together! public List reflectSingle(Class objectClass) { - List result = reflectionCache.get(objectClass); - if (result != null) { - LOG.trace("cyclic reference found; {} is already mapped", objectClass.getName()); - return result; - } + return reflectionCache.computeIfAbsent(objectClass, this::buildSingle); + } - // java primitive type; ignore - if (objectClass.getPackage().getName().equals("java.lang")) return Collections.emptyList(); + // FIXME: this is a slimmed down copy-paste of buildRecursive(); find a way to bring Cached and Reflective listener closer together! + private List buildSingle(Class objectClass) { + if (isPrimitive(objectClass)) return Collections.emptyList(); List nodes = new ArrayList<>(); - reflectionCache.put(objectClass, nodes); ReflectionUtils.doWithFields(objectClass, field -> { String fieldName = field.getName(); @@ -99,7 +105,6 @@ public List reflectSingle(Class objectClass) { } else { Class fieldType = field.getType(); - Type fieldGenericType = field.getGenericType(); if (Collection.class.isAssignableFrom(fieldType)) { nodes.add(new Node(fieldName, documentName, Collections.emptyList(), Node.Type.LIST, field)); @@ -120,9 +125,9 @@ public List reflectSingle(Class objectClass) { return nodes; } - List processParameterizedTypes(Type type) { + List processParameterizedTypes(Type type, HashMap> building) { if (type instanceof Class) { - List children = reflectRecursive((Class) type); + List children = buildRecursive((Class) type, building); if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.DOCUMENT)); } else if (type instanceof ParameterizedType) { @@ -130,11 +135,11 @@ List processParameterizedTypes(Type type) { Class rawType = (Class) subType.getRawType(); if (Collection.class.isAssignableFrom(rawType)) { - List children = processParameterizedTypes(subType.getActualTypeArguments()[0]); + List children = processParameterizedTypes(subType.getActualTypeArguments()[0], building); if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.LIST)); } else if (Map.class.isAssignableFrom(rawType)) { - List children = processParameterizedTypes(subType.getActualTypeArguments()[1]); + List children = processParameterizedTypes(subType.getActualTypeArguments()[1], building); if (!children.isEmpty()) return Collections.singletonList(new Node(null, children, Node.Type.MAP)); } else { @@ -171,4 +176,24 @@ static String parseFieldAnnotation(java.lang.reflect.Field field, String fieldNa } return fieldName; } + + // same as ClassUtils.isPrimitiveOrWrapper(), but also includes String + public static boolean isPrimitive(Class clazz) { + return clazz.isPrimitive() || primitiveClasses.contains(clazz); + } + + private static Set primitiveClasses = new HashSet<>(); + + static { + primitiveClasses.add(Boolean.class); + primitiveClasses.add(Byte.class); + primitiveClasses.add(Character.class); + primitiveClasses.add(Double.class); + primitiveClasses.add(Float.class); + primitiveClasses.add(Integer.class); + primitiveClasses.add(Long.class); + primitiveClasses.add(Short.class); + primitiveClasses.add(Void.class); + primitiveClasses.add(String.class); + } } diff --git a/src/main/java/com/bol/secure/ReflectionEncryptionEventListener.java b/src/main/java/com/bol/secure/ReflectionEncryptionEventListener.java index 6edaf98..9fa139b 100644 --- a/src/main/java/com/bol/secure/ReflectionEncryptionEventListener.java +++ b/src/main/java/com/bol/secure/ReflectionEncryptionEventListener.java @@ -16,9 +16,9 @@ import java.util.Map; import java.util.function.Function; -import static com.bol.reflection.Node.Type.*; +import static com.bol.reflection.Node.Type.DIRECT; +import static com.bol.reflection.ReflectionCache.isPrimitive; -// FIXME: check if we could bring CachedEncryptionEventListener and ReflectionEncryptionEventListener closer together; they are after all doing the same thing, just one at startup, one runtime /** * This is a reimplementation of {@link CachedEncryptionEventListener}, to support polymorphism. * This means that while instead of walking by pre-cached class reflection, we have to walk by the Document provided and @@ -62,8 +62,8 @@ void cryptDocument(Document document, Class clazz, Function cryp } void diveInto(Object value, Type type, Function crypt) { - // primitive type, nothing to do here - if (value.getClass().getPackage().getName().equals("java.lang")) return; + // java primitive type; ignore + if (isPrimitive(value.getClass())) return; Class reflectiveClass = null; Type[] typeArguments = null; diff --git a/src/test/java/com/bol/system/EncryptSystemTest.java b/src/test/java/com/bol/system/EncryptSystemTest.java index 82a1e05..b7d1e43 100644 --- a/src/test/java/com/bol/system/EncryptSystemTest.java +++ b/src/test/java/com/bol/system/EncryptSystemTest.java @@ -10,6 +10,7 @@ import org.bson.types.ObjectId; import org.junit.Before; import org.junit.Test; +import org.junit.internal.Throwables; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.query.Query; @@ -17,8 +18,13 @@ import org.springframework.test.util.ReflectionTestUtils; import java.util.*; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import static com.bol.crypt.CryptVault.fromSignedByte; +import static com.bol.system.model.InitBean.*; import static com.bol.system.model.MyBean.MONGO_NONSENSITIVEDATA; import static com.bol.system.model.MyBean.MONGO_SECRETSTRING; import static org.assertj.core.api.Assertions.assertThat; @@ -38,6 +44,7 @@ public void cleanDb() { mongoTemplate.dropCollection(Person.class); mongoTemplate.dropCollection(RenamedField.class); mongoTemplate.dropCollection(PrimitiveField.class); + mongoTemplate.dropCollection(InitBean.class); } @Test @@ -650,6 +657,43 @@ public void checkMultipleEncryptVersion() { } } + // ReflectionCache is not initialized yet, and we hammer building it in parallel + @Test + public void checkParallelInitialization() { + int nThreads = Math.max(4, Runtime.getRuntime().availableProcessors()); + ExecutorService executorService = Executors.newFixedThreadPool(nThreads); + + ArrayList> futures = new ArrayList<>(); + + for (int i = 0; i < nThreads; i++) { + futures.add( + executorService.submit(() -> { + InitBean initBean = new InitBean(); + initBean.addSubBean("my data 2"); + initBean.data1 = "my data 1"; + mongoTemplate.save(initBean); + return initBean.id; + }) + ); + } + + futures.forEach(f -> { + try { + String id = f.get(10, TimeUnit.SECONDS); + assertThat(id).isNotNull(); + + Document fromMongo = mongoTemplate.getCollection(InitBean.MONGO_INITBEAN).find(new Document("_id", new ObjectId(id))).first(); + Object data1 = fromMongo.get(MONGO_DATA1); + assertThat(data1).isInstanceOf(Binary.class); + List list = (List)fromMongo.get(MONGO_SUB_BEANS); + assertThat(list).hasSize(1); + Document subBean = (Document)list.get(0); + assertThat(subBean.get(MONGO_DATA2)).isInstanceOf(Binary.class); + } catch (Exception e) { + } + }); + } + byte[] cryptedResultInDb(String value) { MyBean bean = new MyBean(); bean.secretString = value; diff --git a/src/test/java/com/bol/system/model/InitBean.java b/src/test/java/com/bol/system/model/InitBean.java new file mode 100644 index 0000000..8e4d094 --- /dev/null +++ b/src/test/java/com/bol/system/model/InitBean.java @@ -0,0 +1,36 @@ +package com.bol.system.model; + +import com.bol.secure.Encrypted; +import org.springframework.data.annotation.Id; +import org.springframework.data.mongodb.core.mapping.Document; + +import java.util.ArrayList; +import java.util.List; + +import static com.bol.system.model.InitBean.MONGO_INITBEAN; + +@Document(collection = MONGO_INITBEAN) +public class InitBean { + public static final String MONGO_INITBEAN = "initbean"; + public static final String MONGO_SUB_BEANS = "subBeans"; + public static final String MONGO_DATA1 = "data1"; + public static final String MONGO_DATA2 = "data2"; + + @Id + public String id; + @Encrypted + public String data1; + public List subBeans = new ArrayList<>(); + + // this also tests non-public subdocuments + public void addSubBean(String input) { + InitSubBean initSubBean = new InitSubBean(); + initSubBean.data2 = input; + subBeans.add(initSubBean); + } +} + +class InitSubBean { + @Encrypted + public String data2; +} diff --git a/src/test/java/com/bol/system/polymorphism/PolymorphismSystemTest.java b/src/test/java/com/bol/system/polymorphism/PolymorphismSystemTest.java index c7c06db..48af8ba 100644 --- a/src/test/java/com/bol/system/polymorphism/PolymorphismSystemTest.java +++ b/src/test/java/com/bol/system/polymorphism/PolymorphismSystemTest.java @@ -1,6 +1,5 @@ package com.bol.system.polymorphism; -import com.bol.crypt.CryptVault; import com.bol.system.model.Person; import com.bol.system.polymorphism.model.SubObject; import com.bol.system.polymorphism.model.TestObject; @@ -27,10 +26,7 @@ @SpringBootTest(classes = {ReflectionMongoDBConfiguration.class}) public class PolymorphismSystemTest { - @Autowired - MongoTemplate mongoTemplate; - @Autowired - CryptVault cryptVault; + @Autowired MongoTemplate mongoTemplate; @Before public void cleanDb() {