From 2f870c72fe7b14a03c7ed0df33df29cfbe25f263 Mon Sep 17 00:00:00 2001 From: Shikhar Jain <8859327+shikharj05@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:47:59 +0530 Subject: [PATCH] Refactor SafeSerializationUtils for better performance (#4973) Signed-off-by: shikharj05 <8859327+shikharj05@users.noreply.github.com> --- .../support/SafeSerializationUtils.java | 25 ++-- .../support/SafeSerializationUtilsTest.java | 119 ++++++++++++++++++ 2 files changed, 136 insertions(+), 8 deletions(-) create mode 100644 src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java diff --git a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java index b58e4afd35..de55334a99 100644 --- a/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java +++ b/src/main/java/org/opensearch/security/support/SafeSerializationUtils.java @@ -17,12 +17,11 @@ import java.net.SocketAddress; import java.util.Collection; import java.util.Collections; -import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Pattern; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import org.opensearch.security.auth.UserInjector; @@ -57,7 +56,7 @@ public final class SafeSerializationUtils { LdapAttribute.class ); - private static final List> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableList.of( + private static final Set> SAFE_ASSIGNABLE_FROM_CLASSES = ImmutableSet.of( InetAddress.class, Number.class, Collection.class, @@ -66,12 +65,23 @@ public final class SafeSerializationUtils { ); private static final Set SAFE_CLASS_NAMES = Collections.singleton("org.ldaptive.LdapAttribute$LdapAttributeValues"); + static final Map, Boolean> safeClassCache = new ConcurrentHashMap<>(); static boolean isSafeClass(Class cls) { - return cls.isArray() - || SAFE_CLASSES.contains(cls) - || SAFE_CLASS_NAMES.contains(cls.getName()) - || SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); + return safeClassCache.computeIfAbsent(cls, SafeSerializationUtils::computeIsSafeClass); + } + + static boolean computeIsSafeClass(Class cls) { + return cls.isArray() || SAFE_CLASSES.contains(cls) || SAFE_CLASS_NAMES.contains(cls.getName()) || isAssignableFromSafeClass(cls); + } + + private static boolean isAssignableFromSafeClass(Class cls) { + for (Class safeClass : SAFE_ASSIGNABLE_FROM_CLASSES) { + if (safeClass.isAssignableFrom(cls)) { + return true; + } + } + return false; } static void prohibitUnsafeClasses(Class clazz) throws IOException { @@ -79,5 +89,4 @@ static void prohibitUnsafeClasses(Class clazz) throws IOException { throw new IOException("Unauthorized serialization attempt " + clazz.getName()); } } - } diff --git a/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java b/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java new file mode 100644 index 0000000000..f69d4e0291 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/SafeSerializationUtilsTest.java @@ -0,0 +1,119 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + */ + +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.regex.Pattern; + +import org.junit.Test; + +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.User; + +import com.amazon.dlic.auth.ldap.LdapUser; +import org.ldaptive.AbstractLdapBean; +import org.ldaptive.LdapAttribute; +import org.ldaptive.LdapEntry; +import org.ldaptive.SearchEntry; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class SafeSerializationUtilsTest { + + @Test + public void testSafeClasses() { + assertTrue(SafeSerializationUtils.isSafeClass(String.class)); + assertTrue(SafeSerializationUtils.isSafeClass(InetSocketAddress.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Pattern.class)); + assertTrue(SafeSerializationUtils.isSafeClass(User.class)); + assertTrue(SafeSerializationUtils.isSafeClass(UserInjector.InjectedUser.class)); + assertTrue(SafeSerializationUtils.isSafeClass(SourceFieldsContext.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapUser.class)); + assertTrue(SafeSerializationUtils.isSafeClass(SearchEntry.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapEntry.class)); + assertTrue(SafeSerializationUtils.isSafeClass(AbstractLdapBean.class)); + assertTrue(SafeSerializationUtils.isSafeClass(LdapAttribute.class)); + } + + @Test + public void testSafeAssignableClasses() { + assertTrue(SafeSerializationUtils.isSafeClass(InetAddress.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Integer.class)); + assertTrue(SafeSerializationUtils.isSafeClass(ArrayList.class)); + assertTrue(SafeSerializationUtils.isSafeClass(HashMap.class)); + assertTrue(SafeSerializationUtils.isSafeClass(Enum.class)); + } + + @Test + public void testArraysAreSafe() { + assertTrue(SafeSerializationUtils.isSafeClass(String[].class)); + assertTrue(SafeSerializationUtils.isSafeClass(int[].class)); + assertTrue(SafeSerializationUtils.isSafeClass(Object[].class)); + } + + @Test + public void testUnsafeClasses() { + assertFalse(SafeSerializationUtils.isSafeClass(SafeSerializationUtilsTest.class)); + assertFalse(SafeSerializationUtils.isSafeClass(Runtime.class)); + } + + @Test + public void testProhibitUnsafeClasses() { + try { + SafeSerializationUtils.prohibitUnsafeClasses(String.class); + } catch (IOException e) { + fail("Should not throw exception for safe class"); + } + + try { + SafeSerializationUtils.prohibitUnsafeClasses(SafeSerializationUtilsTest.class); + fail("Should throw exception for unsafe class"); + } catch (IOException e) { + assertEquals("Unauthorized serialization attempt " + SafeSerializationUtilsTest.class.getName(), e.getMessage()); + } + } + + @Test + public void testInheritance() { + class CustomArrayList extends ArrayList {} + assertTrue(SafeSerializationUtils.isSafeClass(CustomArrayList.class)); + + class CustomMap extends HashMap {} + assertTrue(SafeSerializationUtils.isSafeClass(CustomMap.class)); + } + + @Test + public void testCaching() { + // First call should compute the result + boolean result1 = SafeSerializationUtils.isSafeClass(String.class); + assertTrue(result1); + + // Second call should use cached result + boolean result2 = SafeSerializationUtils.isSafeClass(String.class); + assertTrue(result2); + + // Verify that the cache was used (size should be 1) + assertEquals(1, SafeSerializationUtils.safeClassCache.size()); + + // Third call for a different class + boolean result3 = SafeSerializationUtils.isSafeClass(Integer.class); + assertTrue(result3); + // Verify that the cache was updated + assertEquals(2, SafeSerializationUtils.safeClassCache.size()); + } +}