diff --git a/etor/src/main/java/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSender.java b/etor/src/main/java/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSender.java index b0461fbdf..face1ca87 100644 --- a/etor/src/main/java/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSender.java +++ b/etor/src/main/java/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSender.java @@ -39,28 +39,19 @@ public class ReportStreamOrderSender implements OrderSender { private static final String OUR_PRIVATE_KEY_ID = "trusted-intermediary-private-key-" + ApplicationContext.getEnvironment(); + private static final String RS_TOKEN_CACHE_ID = "report-stream-token"; private static final String CLIENT_NAME = "flexion.etor-service-sender"; private static final Map RS_AUTH_API_HEADERS = Map.of("Content-Type", "application/x-www-form-urlencoded"); - private String rsTokenCache; - - protected synchronized String getRsTokenCache() { - return this.rsTokenCache; - } - - protected synchronized void setRsTokenCache(String token) { - this.rsTokenCache = token; - } - @Inject private HttpClient client; @Inject private AuthEngine jwt; @Inject private Formatter formatter; @Inject private HapiFhir fhir; @Inject private Logger logger; @Inject private Secrets secrets; - @Inject private Cache keyCache; + @Inject private Cache cache; public static ReportStreamOrderSender getInstance() { return INSTANCE; @@ -93,19 +84,22 @@ protected void logRsSubmissionId(String rsResponseBody) { protected String getRsToken() throws UnableToSendOrderException { logger.logInfo("Looking up ReportStream token"); - if (getRsTokenCache() != null && isValidToken()) { + + var token = cache.get(RS_TOKEN_CACHE_ID); + + if (token != null && isValidToken(token)) { logger.logDebug("valid cache token"); - return getRsTokenCache(); + return token; } - String token = requestToken(); - setRsTokenCache(token); + token = requestToken(); + + cache.put(RS_TOKEN_CACHE_ID, token); return token; } - protected boolean isValidToken() { - String token = getRsTokenCache(); + protected boolean isValidToken(String token) { LocalDateTime expirationDate = jwt.getExpirationDate(token); return LocalDateTime.now().isBefore(expirationDate.minus(15, ChronoUnit.SECONDS)); @@ -164,7 +158,7 @@ protected String requestToken() throws UnableToSendOrderException { } protected String retrievePrivateKey() throws SecretRetrievalException { - String key = keyCache.get(OUR_PRIVATE_KEY_ID); + String key = cache.get(OUR_PRIVATE_KEY_ID); if (key != null) { return key; } @@ -175,12 +169,12 @@ protected String retrievePrivateKey() throws SecretRetrievalException { } void cacheOurPrivateKeyIfNotCachedAlready(String privateKey) { - String key = keyCache.get(OUR_PRIVATE_KEY_ID); + String key = cache.get(OUR_PRIVATE_KEY_ID); if (key != null) { return; } - keyCache.put(OUR_PRIVATE_KEY_ID, privateKey); + cache.put(OUR_PRIVATE_KEY_ID, privateKey); } protected String extractToken(String responseBody) throws FormatterProcessingException { diff --git a/etor/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSenderTest.groovy b/etor/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSenderTest.groovy index 38e6c47cd..f80b70244 100644 --- a/etor/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSenderTest.groovy +++ b/etor/src/test/groovy/gov/hhs/cdc/trustedintermediary/external/reportstream/ReportStreamOrderSenderTest.groovy @@ -8,20 +8,18 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache import gov.hhs.cdc.trustedintermediary.external.jackson.Jackson import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine import gov.hhs.cdc.trustedintermediary.wrappers.Cache -import gov.hhs.cdc.trustedintermediary.wrappers.Logger -import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter -import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException import gov.hhs.cdc.trustedintermediary.wrappers.HapiFhir import gov.hhs.cdc.trustedintermediary.wrappers.HttpClient import gov.hhs.cdc.trustedintermediary.wrappers.HttpClientException +import gov.hhs.cdc.trustedintermediary.wrappers.Logger import gov.hhs.cdc.trustedintermediary.wrappers.Secrets +import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter +import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException import gov.hhs.cdc.trustedintermediary.wrappers.formatter.TypeReference import java.time.LocalDateTime import java.time.temporal.ChronoUnit import spock.lang.Specification -import java.util.concurrent.ConcurrentHashMap - class ReportStreamOrderSenderTest extends Specification { def setup() { @@ -310,61 +308,19 @@ class ReportStreamOrderSenderTest extends Specification { def "ensure jwt that expires 15 seconds from now is valid"() { given: def mockAuthEngine = Mock(AuthEngine) - TestApplicationContext.register(AuthEngine, mockAuthEngine) + mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(20, ChronoUnit.SECONDS) - TestApplicationContext.register(OrderSender, ReportStreamOrderSender.getInstance()) + + TestApplicationContext.register(AuthEngine, mockAuthEngine) TestApplicationContext.injectRegisteredImplementations() - ReportStreamOrderSender.getInstance().setRsTokenCache("our token from rs") when: - def isValid = ReportStreamOrderSender.getInstance().isValidToken() + def isValid = ReportStreamOrderSender.getInstance().isValidToken("our token from rs") then: isValid } - def "rsTokenCache getter and setter works, no synchronization"() { - given: - def rsOrderSender = ReportStreamOrderSender.getInstance() - def expected = "fake token" - - when: - rsOrderSender.setRsTokenCache(expected) - def actual = rsOrderSender.getRsTokenCache() - - then: - actual == expected - } - - def "rsTokenCache synchronization works"() { - given: - def orderSender = ReportStreamOrderSender.getInstance() - def threadNums = 5 - def iterations = 25 - def table = new ConcurrentHashMap() - - when: - List threads = [] - (1..threadNums).each { threadId -> - threads.add(new Thread({ - for(int i=0; i> null + + def freshTokenFromRs = "new token" + mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs] + TestApplicationContext.register(Formatter, mockFormatter) - TestApplicationContext.register(AuthEngine, mockAuthEngine) + TestApplicationContext.register(AuthEngine, Mock(AuthEngine)) TestApplicationContext.register(HttpClient, mockClient) - TestApplicationContext.register(Secrets, mockSecrets) - mockSecrets.getKey(_ as String) >> "fake private key" - TestApplicationContext.register(OrderSender, orderSender) - TestApplicationContext.injectRegisteredImplementations() + TestApplicationContext.register(Cache, mockCache) + TestApplicationContext.register(Secrets, Mock(Secrets)) - mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS) - mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token" - mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token") - def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}""" - mockClient.post(_ as String, _ as Map, _ as String) >> responseBody + TestApplicationContext.injectRegisteredImplementations() when: - def token = orderSender.getRsToken() + def token = ReportStreamOrderSender.getInstance().getRsToken() then: - token == orderSender.getRsTokenCache() + 1 * mockClient.post(_, _, _) + token == freshTokenFromRs } - def "getRsToken when cache token is invalid"() { + def "getRsToken when cache token is invalid we call RS to get a new one"() { given: - def orderSender = ReportStreamOrderSender.getInstance() def mockClient = Mock(HttpClient) def mockAuthEngine = Mock(AuthEngine) - def mockSecrets = Mock(Secrets) def mockFormatter = Mock(Formatter) + def mockCache = Mock(Cache) + + mockCache.get(_ as String) >> "shouldn't be returned" + + //mock the auth engine so that the JWT looks like it is invalid + mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS) + + def freshTokenFromRs = "new token" + mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs] + TestApplicationContext.register(Formatter, mockFormatter) TestApplicationContext.register(AuthEngine, mockAuthEngine) TestApplicationContext.register(HttpClient, mockClient) - TestApplicationContext.register(Secrets, mockSecrets) - mockSecrets.getKey(_ as String) >> "fakePrivateKey" - TestApplicationContext.register(OrderSender, orderSender) - TestApplicationContext.injectRegisteredImplementations() + TestApplicationContext.register(Cache, mockCache) + TestApplicationContext.register(Secrets, Mock(Secrets)) - mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token" - mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS) - mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token") - def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}""" - mockClient.post(_ as String, _ as Map, _ as String) >> responseBody - orderSender.setRsTokenCache("Invalid Token") + TestApplicationContext.injectRegisteredImplementations() when: - def token = orderSender.getRsToken() + def token = ReportStreamOrderSender.getInstance().getRsToken() then: - token == orderSender.getRsTokenCache() + 1 * mockClient.post(_, _, _) + token == freshTokenFromRs } - def "getRsToken when cache token is valid"() { + def "getRsToken when cache token is valid, return that cached token"() { given: - def orderSender = ReportStreamOrderSender.getInstance() - orderSender.setRsTokenCache("valid Token") - TestApplicationContext.register(OrderSender, orderSender) + def mockAuthEngine = Mock(AuthEngine) + def mockCache = Mock(Cache) - def mockFormatter = Mock(Formatter) - mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token") - TestApplicationContext.register(Formatter, mockFormatter) + def cachedRsToken = "DogCow goes Moof!" + mockCache.get(_ as String) >> cachedRsToken - def mockLogFormatter = Mock(Formatter) - mockLogFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> null - TestApplicationContext.register(Formatter, mockLogFormatter) + //mock the auth engine so that the JWT looks valid + mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(60, ChronoUnit.SECONDS) - def mockAuthEngine = Mock(AuthEngine) - mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token" - mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(25, ChronoUnit.SECONDS) TestApplicationContext.register(AuthEngine, mockAuthEngine) - - def mockClient = Mock(HttpClient) - mockClient.post(_ as String, _ as Map, _ as String) >> """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}""" - TestApplicationContext.register(HttpClient, mockClient) - - def mockSecrets = Mock(Secrets) - mockSecrets.getKey(_ as String) >> "fakePrivateKey" - TestApplicationContext.register(Secrets, mockSecrets) + TestApplicationContext.register(Cache, mockCache) TestApplicationContext.injectRegisteredImplementations() when: - def token = orderSender.getRsToken() + def token = ReportStreamOrderSender.getInstance().getRsToken() then: - token == orderSender.getRsTokenCache() + token == cachedRsToken } def "logRsSubmissionId logs submissionId if convertJsonToObject is successful"() {