Skip to content

Commit

Permalink
Merge pull request #614 from CDCgov/rs-token-cache-improvement
Browse files Browse the repository at this point in the history
613: RS Token Cache Improvement
  • Loading branch information
halprin authored Oct 27, 2023
2 parents 6195a7d + 28d4505 commit d6429cd
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,19 @@ public class ReportStreamOrderSender implements OrderSender {

private static final String OUR_PRIVATE_KEY_ID =
"trusted-intermediary-private-key-" + ApplicationContext.getEnvironment();
private static final String RS_TOKEN_CACHE_ID = "report-stream-token";

private static final String CLIENT_NAME = "flexion.etor-service-sender";
private static final Map<String, String> RS_AUTH_API_HEADERS =
Map.of("Content-Type", "application/x-www-form-urlencoded");

private String rsTokenCache;

protected synchronized String getRsTokenCache() {
return this.rsTokenCache;
}

protected synchronized void setRsTokenCache(String token) {
this.rsTokenCache = token;
}

@Inject private HttpClient client;
@Inject private AuthEngine jwt;
@Inject private Formatter formatter;
@Inject private HapiFhir fhir;
@Inject private Logger logger;
@Inject private Secrets secrets;
@Inject private Cache keyCache;
@Inject private Cache cache;

public static ReportStreamOrderSender getInstance() {
return INSTANCE;
Expand Down Expand Up @@ -93,19 +84,22 @@ protected void logRsSubmissionId(String rsResponseBody) {

protected String getRsToken() throws UnableToSendOrderException {
logger.logInfo("Looking up ReportStream token");
if (getRsTokenCache() != null && isValidToken()) {

var token = cache.get(RS_TOKEN_CACHE_ID);

if (token != null && isValidToken(token)) {
logger.logDebug("valid cache token");
return getRsTokenCache();
return token;
}

String token = requestToken();
setRsTokenCache(token);
token = requestToken();

cache.put(RS_TOKEN_CACHE_ID, token);

return token;
}

protected boolean isValidToken() {
String token = getRsTokenCache();
protected boolean isValidToken(String token) {
LocalDateTime expirationDate = jwt.getExpirationDate(token);

return LocalDateTime.now().isBefore(expirationDate.minus(15, ChronoUnit.SECONDS));
Expand Down Expand Up @@ -164,7 +158,7 @@ protected String requestToken() throws UnableToSendOrderException {
}

protected String retrievePrivateKey() throws SecretRetrievalException {
String key = keyCache.get(OUR_PRIVATE_KEY_ID);
String key = cache.get(OUR_PRIVATE_KEY_ID);
if (key != null) {
return key;
}
Expand All @@ -175,12 +169,12 @@ protected String retrievePrivateKey() throws SecretRetrievalException {
}

void cacheOurPrivateKeyIfNotCachedAlready(String privateKey) {
String key = keyCache.get(OUR_PRIVATE_KEY_ID);
String key = cache.get(OUR_PRIVATE_KEY_ID);
if (key != null) {
return;
}

keyCache.put(OUR_PRIVATE_KEY_ID, privateKey);
cache.put(OUR_PRIVATE_KEY_ID, privateKey);
}

protected String extractToken(String responseBody) throws FormatterProcessingException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ import gov.hhs.cdc.trustedintermediary.external.inmemory.KeyCache
import gov.hhs.cdc.trustedintermediary.external.jackson.Jackson
import gov.hhs.cdc.trustedintermediary.wrappers.AuthEngine
import gov.hhs.cdc.trustedintermediary.wrappers.Cache
import gov.hhs.cdc.trustedintermediary.wrappers.Logger
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException
import gov.hhs.cdc.trustedintermediary.wrappers.HapiFhir
import gov.hhs.cdc.trustedintermediary.wrappers.HttpClient
import gov.hhs.cdc.trustedintermediary.wrappers.HttpClientException
import gov.hhs.cdc.trustedintermediary.wrappers.Logger
import gov.hhs.cdc.trustedintermediary.wrappers.Secrets
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.Formatter
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.FormatterProcessingException
import gov.hhs.cdc.trustedintermediary.wrappers.formatter.TypeReference
import java.time.LocalDateTime
import java.time.temporal.ChronoUnit
import spock.lang.Specification

import java.util.concurrent.ConcurrentHashMap

class ReportStreamOrderSenderTest extends Specification {

def setup() {
Expand Down Expand Up @@ -310,61 +308,19 @@ class ReportStreamOrderSenderTest extends Specification {
def "ensure jwt that expires 15 seconds from now is valid"() {
given:
def mockAuthEngine = Mock(AuthEngine)
TestApplicationContext.register(AuthEngine, mockAuthEngine)

mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(20, ChronoUnit.SECONDS)
TestApplicationContext.register(OrderSender, ReportStreamOrderSender.getInstance())

TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.injectRegisteredImplementations()
ReportStreamOrderSender.getInstance().setRsTokenCache("our token from rs")

when:
def isValid = ReportStreamOrderSender.getInstance().isValidToken()
def isValid = ReportStreamOrderSender.getInstance().isValidToken("our token from rs")

then:
isValid
}

def "rsTokenCache getter and setter works, no synchronization"() {
given:
def rsOrderSender = ReportStreamOrderSender.getInstance()
def expected = "fake token"

when:
rsOrderSender.setRsTokenCache(expected)
def actual = rsOrderSender.getRsTokenCache()

then:
actual == expected
}

def "rsTokenCache synchronization works"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def threadNums = 5
def iterations = 25
def table = new ConcurrentHashMap<String, Integer>()

when:
List<Thread> threads = []
(1..threadNums).each { threadId ->
threads.add(new Thread({
for(int i=0; i<iterations; i++) {
orderSender.setRsTokenCache("${i}")
if (i == 24) {
table.put("thread"+"${threadId}", i)
}
}
}))
}

threads*.start()
threads*.join()

then:
orderSender.getRsTokenCache() == "${iterations - 1}"
table.size() == threadNums
table.values().toSet().size() == 1
}

def "sendRequestBody bombs out due to http exception"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
Expand All @@ -385,97 +341,86 @@ class ReportStreamOrderSenderTest extends Specification {
exception.getCause().getClass() == HttpClientException
}

def "getRsToken when cache is empty"() {
def "getRsToken when cache is empty we call RS to get a new one"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def mockClient = Mock(HttpClient)
def mockAuthEngine = Mock(AuthEngine)
def mockSecrets = Mock(Secrets)
def mockFormatter = Mock(Formatter)
def mockCache = Mock(Cache)

//make the cache empty
mockCache.get(_ as String) >> null

def freshTokenFromRs = "new token"
mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs]

TestApplicationContext.register(Formatter, mockFormatter)
TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.register(AuthEngine, Mock(AuthEngine))
TestApplicationContext.register(HttpClient, mockClient)
TestApplicationContext.register(Secrets, mockSecrets)
mockSecrets.getKey(_ as String) >> "fake private key"
TestApplicationContext.register(OrderSender, orderSender)
TestApplicationContext.injectRegisteredImplementations()
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(Secrets, Mock(Secrets))

mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)
mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
mockClient.post(_ as String, _ as Map, _ as String) >> responseBody
TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
1 * mockClient.post(_, _, _)
token == freshTokenFromRs
}

def "getRsToken when cache token is invalid"() {
def "getRsToken when cache token is invalid we call RS to get a new one"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
def mockClient = Mock(HttpClient)
def mockAuthEngine = Mock(AuthEngine)
def mockSecrets = Mock(Secrets)
def mockFormatter = Mock(Formatter)
def mockCache = Mock(Cache)

mockCache.get(_ as String) >> "shouldn't be returned"

//mock the auth engine so that the JWT looks like it is invalid
mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)

def freshTokenFromRs = "new token"
mockFormatter.convertJsonToObject(_, _ as TypeReference) >> [access_token: freshTokenFromRs]

TestApplicationContext.register(Formatter, mockFormatter)
TestApplicationContext.register(AuthEngine, mockAuthEngine)
TestApplicationContext.register(HttpClient, mockClient)
TestApplicationContext.register(Secrets, mockSecrets)
mockSecrets.getKey(_ as String) >> "fakePrivateKey"
TestApplicationContext.register(OrderSender, orderSender)
TestApplicationContext.injectRegisteredImplementations()
TestApplicationContext.register(Cache, mockCache)
TestApplicationContext.register(Secrets, Mock(Secrets))

mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(10, ChronoUnit.SECONDS)
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
def responseBody = """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
mockClient.post(_ as String, _ as Map, _ as String) >> responseBody
orderSender.setRsTokenCache("Invalid Token")
TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
1 * mockClient.post(_, _, _)
token == freshTokenFromRs
}

def "getRsToken when cache token is valid"() {
def "getRsToken when cache token is valid, return that cached token"() {
given:
def orderSender = ReportStreamOrderSender.getInstance()
orderSender.setRsTokenCache("valid Token")
TestApplicationContext.register(OrderSender, orderSender)
def mockAuthEngine = Mock(AuthEngine)
def mockCache = Mock(Cache)

def mockFormatter = Mock(Formatter)
mockFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> Map.of("access_token", "fake token")
TestApplicationContext.register(Formatter, mockFormatter)
def cachedRsToken = "DogCow goes Moof!"
mockCache.get(_ as String) >> cachedRsToken

def mockLogFormatter = Mock(Formatter)
mockLogFormatter.convertJsonToObject(_ as String, _ as TypeReference) >> null
TestApplicationContext.register(Formatter, mockLogFormatter)
//mock the auth engine so that the JWT looks valid
mockAuthEngine.getExpirationDate(_) >> LocalDateTime.now().plus(60, ChronoUnit.SECONDS)

def mockAuthEngine = Mock(AuthEngine)
mockAuthEngine.generateSenderToken(_ as String, _ as String, _ as String, _ as String, 300) >> "fake token"
mockAuthEngine.getExpirationDate(_ as String) >> LocalDateTime.now().plus(25, ChronoUnit.SECONDS)
TestApplicationContext.register(AuthEngine, mockAuthEngine)

def mockClient = Mock(HttpClient)
mockClient.post(_ as String, _ as Map, _ as String) >> """{"foo":"foo value", "access_token":fake token, "boo":"boo value"}"""
TestApplicationContext.register(HttpClient, mockClient)

def mockSecrets = Mock(Secrets)
mockSecrets.getKey(_ as String) >> "fakePrivateKey"
TestApplicationContext.register(Secrets, mockSecrets)
TestApplicationContext.register(Cache, mockCache)

TestApplicationContext.injectRegisteredImplementations()

when:
def token = orderSender.getRsToken()
def token = ReportStreamOrderSender.getInstance().getRsToken()

then:
token == orderSender.getRsTokenCache()
token == cachedRsToken
}

def "logRsSubmissionId logs submissionId if convertJsonToObject is successful"() {
Expand Down

0 comments on commit d6429cd

Please sign in to comment.