diff --git a/build.gradle b/build.gradle index 76fb175b..dcc651f4 100644 --- a/build.gradle +++ b/build.gradle @@ -12,6 +12,7 @@ dependencies { implementation 'org.bouncycastle:bcprov-jdk18on:1.77' testImplementation 'org.junit.jupiter:junit-jupiter:5.10.2' + testImplementation 'org.mockito:mockito-core:5.14.2' } repositories { diff --git a/src/main/java/com/apple/itunes/storekit/verification/ChainVerifier.java b/src/main/java/com/apple/itunes/storekit/verification/ChainVerifier.java index 8d801050..3bf5db85 100644 --- a/src/main/java/com/apple/itunes/storekit/verification/ChainVerifier.java +++ b/src/main/java/com/apple/itunes/storekit/verification/ChainVerifier.java @@ -14,18 +14,32 @@ import java.security.cert.PKIXRevocationChecker; import java.security.cert.TrustAnchor; import java.security.cert.X509Certificate; +import java.time.Clock; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; import java.util.Base64; import java.util.Date; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; public class ChainVerifier { private static final int EXPECTED_CHAIN_LENGTH = 3; + private static final int MAXIMUM_CACHE_SIZE = 32; // There are unlikely to be more than a couple keys at once + private static final int CACHE_TIME_LIMIT = 15; // 15 minutes + private final Set trustAnchors; + private final ConcurrentHashMap, CachedEntry> verifiedPublicKeyCache; + private final Clock clock; public ChainVerifier(Set rootCertificates) { + this(rootCertificates, Clock.systemUTC()); + } + + ChainVerifier(Set rootCertificates, Clock clock) { try { CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509"); this.trustAnchors = new HashSet<>(); @@ -42,9 +56,26 @@ public ChainVerifier(Set rootCertificates) { if (trustAnchors.size() == 0) { throw new RuntimeException("At least one root certificate is required"); } + this.verifiedPublicKeyCache = new ConcurrentHashMap<>(); + this.clock = clock; } public PublicKey verifyChain(String[] certificates, boolean performRevocationChecking, Date effectiveDate) throws VerificationException { + if (performRevocationChecking && certificates.length > 0) { + // If revocation checking is enabled (which also implies effectiveDate is now), check the cache + PublicKey cachedKey = getCachedPrivateKey(Arrays.asList(certificates)); + if (cachedKey != null) { + return cachedKey; + } + } + PublicKey publicKey = verifyChainWithoutCaching(certificates, performRevocationChecking, effectiveDate); + if (performRevocationChecking) { + putVerifiedPublicKey(Arrays.asList(certificates), publicKey); + } + return publicKey; + } + + PublicKey verifyChainWithoutCaching(String[] certificates, boolean performRevocationChecking, Date effectiveDate) throws VerificationException { CertificateFactory certificateFactory; CertPathValidator certPathValidator; try { @@ -85,4 +116,30 @@ public PublicKey verifyChain(String[] certificates, boolean performRevocationChe throw new VerificationException(VerificationStatus.INVALID_CHAIN, e); } } + + private PublicKey getCachedPrivateKey(List certificateChain) { + if (verifiedPublicKeyCache.containsKey(certificateChain) && verifiedPublicKeyCache.get(certificateChain).cachedExpirationDate.isAfter(clock.instant())) { + return verifiedPublicKeyCache.get(certificateChain).publicKey; + } + return null; + } + + private void putVerifiedPublicKey(List certificateChain, PublicKey publicKey) { + Instant cacheExpiration = clock.instant().plus(CACHE_TIME_LIMIT, ChronoUnit.MINUTES); + verifiedPublicKeyCache.put(certificateChain, new CachedEntry(cacheExpiration, publicKey)); + if (verifiedPublicKeyCache.size() > MAXIMUM_CACHE_SIZE) { + // In the very unlikely event that the map has become too large, clear out old entries + verifiedPublicKeyCache.entrySet().removeIf(e -> e.getValue().cachedExpirationDate.isBefore(clock.instant())); + } + } + + private static class CachedEntry { + private final Instant cachedExpirationDate; + private final PublicKey publicKey; + + public CachedEntry(Instant cachedExpirationDate, PublicKey publicKey) { + this.cachedExpirationDate = cachedExpirationDate; + this.publicKey = publicKey; + } + } } diff --git a/src/test/java/com/apple/itunes/storekit/verification/ChainVerifierTest.java b/src/test/java/com/apple/itunes/storekit/verification/ChainVerifierTest.java index 750049a2..e9f5bd69 100644 --- a/src/test/java/com/apple/itunes/storekit/verification/ChainVerifierTest.java +++ b/src/test/java/com/apple/itunes/storekit/verification/ChainVerifierTest.java @@ -3,12 +3,16 @@ package com.apple.itunes.storekit.verification; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import java.io.ByteArrayInputStream; import java.nio.charset.StandardCharsets; import java.security.PublicKey; import java.security.cert.CertPathValidatorException; +import java.time.Clock; +import java.time.Instant; import java.util.Base64; import java.util.Date; import java.util.Set; @@ -31,6 +35,23 @@ public class ChainVerifierTest { private static final Date EFFECTIVE_DATE = new Date(1681312846000L); // April 2023 + private Clock clock; + private PublicKey publicKey; + private ChainVerifier mockedChainVerifier; + private static final long CLOCK_DATE = 41231L; + + @BeforeEach + public void setup() throws VerificationException { + clock = Mockito.mock(Clock.class); + publicKey = Mockito.mock(PublicKey.class); + mockedChainVerifier = Mockito.spy(getChainVerifier(ROOT_CA_BASE64_ENCODED)); + Mockito.doReturn(publicKey) + .when(mockedChainVerifier) + .verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + Mockito.when(clock.instant()).thenReturn(Instant.ofEpochMilli(CLOCK_DATE)); + + } + @Test public void testValidChainWithoutOCSP() throws VerificationException { ChainVerifier verifier = getChainVerifier(ROOT_CA_BASE64_ENCODED); @@ -137,6 +158,73 @@ public void testChainDifferentThanRootCertificate() { Assertions.assertInstanceOf(CertPathValidatorException.class, cause); } + @Test + public void testOcspResponseCaching() throws VerificationException { + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + Mockito.verify(mockedChainVerifier, Mockito.times(1)).verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + // Move one second to the future, should be cached + Mockito.when(clock.instant()).thenReturn(Instant.ofEpochMilli(CLOCK_DATE + 1_000)); // 1 second + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + Mockito.verify(mockedChainVerifier, Mockito.times(1)).verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + } + + @Test + public void testOcspResponseCachingHasExpiration() throws VerificationException { + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + // Move 15 minutes into the future (such that the cache has expired) + Mockito.when(clock.instant()).thenReturn(Instant.ofEpochMilli(CLOCK_DATE + 900_000)); // 15 minutes + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + Mockito.verify(mockedChainVerifier, Mockito.times(2)).verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + } + + @Test + public void testOcspResponseCachingWithDifferentChains() throws VerificationException { + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + // Different certificates result in different cache entry + mockedChainVerifier.verifyChain(new String[] { + REAL_APPLE_SIGNING_CERTIFICATE_BASE64_ENCODED, + REAL_APPLE_INTERMEDIATE_BASE64_ENCODED, + REAL_APPLE_ROOT_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + Mockito.verify(mockedChainVerifier, Mockito.times(2)).verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + } + + @Test + public void testOcspResponseCachingWithSlightlyDifferentChains() throws VerificationException { + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + ROOT_CA_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + // Different certificates result in different cache entry + mockedChainVerifier.verifyChain(new String[] { + LEAF_CERT_BASE64_ENCODED, + INTERMEDIATE_CA_BASE64_ENCODED, + REAL_APPLE_ROOT_BASE64_ENCODED + }, true, EFFECTIVE_DATE); + Mockito.verify(mockedChainVerifier, Mockito.times(2)).verifyChainWithoutCaching(Mockito.any(), Mockito.anyBoolean(), Mockito.any()); + } + /** * The following test will communicate with Apple's OCSP servers, disable this test for offline testing */ @@ -150,7 +238,7 @@ public void testAppleChainIsValidWithOCSP() throws VerificationException { }, true, EFFECTIVE_DATE); } - private static ChainVerifier getChainVerifier(String base64EncodedRootCertificate) { - return new ChainVerifier(Set.of(new ByteArrayInputStream(Base64.getDecoder().decode(base64EncodedRootCertificate)))); + private ChainVerifier getChainVerifier(String base64EncodedRootCertificate) { + return new ChainVerifier(Set.of(new ByteArrayInputStream(Base64.getDecoder().decode(base64EncodedRootCertificate))), clock); } }