Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduced retry mechanism for Choreo API endpint request #146

Merged
merged 3 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.wso2.carbon.identity.conditional.auth.functions.choreo.internal.ChoreoFunctionServiceHolder;
import org.wso2.carbon.identity.conditional.auth.functions.common.utils.ConfigProvider;
import org.wso2.carbon.identity.conditional.auth.functions.common.utils.Constants;
import org.wso2.carbon.identity.core.util.IdentityUtil;
import org.wso2.carbon.identity.secret.mgt.core.exception.SecretManagementClientException;
import org.wso2.carbon.identity.secret.mgt.core.exception.SecretManagementException;
import org.wso2.carbon.identity.secret.mgt.core.model.ResolvedSecret;
Expand Down Expand Up @@ -100,13 +101,24 @@ public class CallChoreoFunctionImpl implements CallChoreoFunction {
private static final String BEARER = "Bearer ";
private static final String BASIC = "Basic ";
private static final int MAX_TOKEN_REQUEST_ATTEMPTS = 2;
private int maxTokenRequestAttemptsForTimeOut = 2;
private int maxRequestAttemptsForChoreoAPIEndpointTimeout = 2;

private final ChoreoAccessTokenCache choreoAccessTokenCache;

public CallChoreoFunctionImpl() {

this.choreoDomains = ConfigProvider.getInstance().getChoreoDomains();
this.choreoAccessTokenCache = ChoreoAccessTokenCache.getInstance();

if (StringUtils.isNotBlank(IdentityUtil.getProperty(Constants.CALL_CHOREO_TOKEN_REQUEST_RETRY_COUNT))) {
maxTokenRequestAttemptsForTimeOut = Integer.parseInt
(IdentityUtil.getProperty(Constants.CALL_CHOREO_TOKEN_REQUEST_RETRY_COUNT));
}
if (StringUtils.isNotBlank(IdentityUtil.getProperty(Constants.CALL_CHOREO_API_REQUEST_RETRY_COUNT))) {
maxRequestAttemptsForChoreoAPIEndpointTimeout = Integer.parseInt
(IdentityUtil.getProperty(Constants.CALL_CHOREO_API_REQUEST_RETRY_COUNT));
}
}

@Override
Expand Down Expand Up @@ -252,7 +264,7 @@ private String getParentDomainFromUrl(String url) throws URISyntaxException {
* @throws FrameworkException {@link FrameworkException}
*/
private void requestAccessToken(String tenantDomain, AccessTokenRequestHelper accessTokenRequestHelper)
throws IOException, FrameworkException {
throws IOException, FrameworkException {

String tokenEndpoint;
if (StringUtils.isNotEmpty(accessTokenRequestHelper.getAsgardeoTokenEndpoint())) {
Expand All @@ -266,7 +278,7 @@ private void requestAccessToken(String tenantDomain, AccessTokenRequestHelper ac

request.setHeader(AUTHORIZATION, BASIC + Base64.getEncoder()
.encodeToString((accessTokenRequestHelper.consumerKey + ":" + accessTokenRequestHelper.consumerSecret)
.getBytes(StandardCharsets.UTF_8)));
.getBytes(StandardCharsets.UTF_8)));
Sachin-Mamoru marked this conversation as resolved.
Show resolved Hide resolved

List<BasicNameValuePair> bodyParams = new ArrayList<>();
bodyParams.add(new BasicNameValuePair(GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS));
Expand All @@ -288,6 +300,7 @@ private class AccessTokenRequestHelper implements FutureCallback<HttpResponse> {
private final Gson gson;
private final AtomicInteger tokenRequestAttemptCount;
private final AtomicInteger tokenRequestAttemptCountForTimeOut;
private final AtomicInteger requestAttemptCountForChoreoAPIEndpointTimeOut;
private String consumerKey;
private String consumerSecret;
private String asgardeoTokenEndpoint;
Expand All @@ -304,6 +317,7 @@ public AccessTokenRequestHelper(Map<String, String> connectionMetaData,
this.gson = new GsonBuilder().create();
this.tokenRequestAttemptCount = new AtomicInteger(0);
this.tokenRequestAttemptCountForTimeOut = new AtomicInteger(0);
this.requestAttemptCountForChoreoAPIEndpointTimeOut = new AtomicInteger(0);
resolveConsumerKeySecrete();
}

Expand Down Expand Up @@ -366,15 +380,16 @@ public void completed(HttpResponse httpResponse) {
@Override
public void failed(Exception e) {

LOG.error("Failed to request access token from Choreo for the session data key: " +
LOG.warn("Failed to request access token from Choreo for the session data key: " +
authenticationContext.getContextIdentifier(), e);
try {
String outcome = OUTCOME_FAIL;
if ((e instanceof SocketTimeoutException) || (e instanceof ConnectTimeoutException)) {
outcome = OUTCOME_TIMEOUT;
}
// Retry if the access token request failed due to a timeout or failed scenario.
handleRetryTokenRequest(tokenRequestAttemptCountForTimeOut, outcome);
handleRetryTokenRequest(tokenRequestAttemptCountForTimeOut, outcome,
maxTokenRequestAttemptsForTimeOut);
} catch (Exception ex) {
LOG.error("Error while proceeding after failing to request access token for the session data key: " +
authenticationContext.getContextIdentifier(), e);
Expand Down Expand Up @@ -418,7 +433,7 @@ private void callChoreoEndpoint(String accessToken) {
.getClient(this.authenticationContext.getTenantDomain());
LOG.info("Calling Choreo endpoint for session data key: " +
authenticationContext.getContextIdentifier());
client.execute(request, new FutureCallback<HttpResponse>() {
FutureCallback<HttpResponse> callChoreoEndpointCallback = new FutureCallback<HttpResponse>() {

@Override
public void completed(final HttpResponse response) {
Expand All @@ -436,14 +451,26 @@ public void completed(final HttpResponse response) {
@Override
public void failed(final Exception ex) {

LOG.error("Failed to invoke Choreo for session data key: " +
LOG.warn("Failed to invoke Choreo for session data key: " +
authenticationContext.getContextIdentifier(), ex);
try {
String outcome = Constants.OUTCOME_FAIL;
if ((ex instanceof SocketTimeoutException) || (ex instanceof ConnectTimeoutException)) {
outcome = Constants.OUTCOME_TIMEOUT;
}
asyncReturn.accept(authenticationContext, Collections.emptyMap(), outcome);

if (requestAttemptCountForChoreoAPIEndpointTimeOut
.get() < maxRequestAttemptsForChoreoAPIEndpointTimeout) {
LOG.info("Retrying request for session data key: " +
authenticationContext.getContextIdentifier());
client.execute(request, this);
requestAttemptCountForChoreoAPIEndpointTimeOut.incrementAndGet();
} else {
LOG.warn("Maximum request attempt count exceeded for session data key: " +
authenticationContext.getContextIdentifier());
requestAttemptCountForChoreoAPIEndpointTimeOut.set(0);
asyncReturn.accept(authenticationContext, Collections.emptyMap(), outcome);
}
} catch (Exception e) {
LOG.error("Error while proceeding after failed response from Choreo " +
"call for session data key: " + authenticationContext.getContextIdentifier(), e);
Expand All @@ -462,7 +489,9 @@ public void cancelled() {
"data key: " + authenticationContext.getContextIdentifier(), e);
}
}
});
};

client.execute(request, callChoreoEndpointCallback);
} catch (UnsupportedEncodingException e) {
LOG.error("Error while constructing request payload for calling choreo endpoint. session data key: " +
authenticationContext.getContextIdentifier(), e);
Expand Down Expand Up @@ -519,7 +548,7 @@ private void handleChoreoEndpointResponse(final HttpResponse response) throws Fr
if (ERROR_CODE_ACCESS_TOKEN_INACTIVE.equals(responseBody.get(CODE))) {
LOG.info("Access token inactive for session data key: " +
authenticationContext.getContextIdentifier());
handleRetryTokenRequest(tokenRequestAttemptCount, OUTCOME_FAIL);
handleRetryTokenRequest(tokenRequestAttemptCount, OUTCOME_FAIL, MAX_TOKEN_REQUEST_ATTEMPTS);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
handleRetryTokenRequest(tokenRequestAttemptCount, OUTCOME_FAIL, MAX_TOKEN_REQUEST_ATTEMPTS);
handleRetryTokenRequest(tokenRequestAttemptCount, OUTCOME_FAIL, maxTokenRequestAttemptsForTimeOut);

is the variable correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the retry mechanism is for expired tokens which they had earlier.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That can also be the same as a token timeout retry. IMO expired or timeout retry attempts should be configurable. But for the moment let's keep this.

} else {
LOG.warn("Received 401 response from Choreo. Session data key: " +
authenticationContext.getContextIdentifier());
Expand All @@ -546,13 +575,16 @@ private void handleChoreoEndpointResponse(final HttpResponse response) throws Fr
* token or if it's a time-out. The program will retry the token request flow until it exceeds the specified
* max request attempt count.
*
* @param tokenRequestAttemptCount {@link AtomicInteger}
* @param outcome {@link String}
* @param maxTokenRequestAttempts {@link Integer}
* @throws IOException {@link IOException}
* @throws FrameworkException {@link FrameworkException}
*/
private void handleRetryTokenRequest(AtomicInteger tokenRequestAttemptCount, String outcome)
throws IOException, FrameworkException {
private void handleRetryTokenRequest(AtomicInteger tokenRequestAttemptCount, String outcome,
int maxTokenRequestAttempts) throws IOException, FrameworkException {

if (tokenRequestAttemptCount.get() < MAX_TOKEN_REQUEST_ATTEMPTS) {
if (tokenRequestAttemptCount.get() < maxTokenRequestAttempts) {
LOG.info("Retrying token request for session data key: " +
this.authenticationContext.getContextIdentifier());
requestAccessToken(this.authenticationContext.getTenantDomain(), this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ public class Constants {

public static final String CALL_CHOREO_HTTP_CONNECTION_REQUEST_TIMEOUT = "AdaptiveAuth.CallChoreo.HTTPConnectionRequestTimeout";
public static final String CALL_CHOREO_HTTP_READ_TIMEOUT = "AdaptiveAuth.CallChoreo.HTTPReadTimeout";
public static final String CALL_CHOREO_TOKEN_REQUEST_RETRY_COUNT = "AdaptiveAuth.CallChoreo.TokenRequestRetryCount";
public static final String CALL_CHOREO_API_REQUEST_RETRY_COUNT = "AdaptiveAuth.CallChoreo.ChoreoAPIRequestRetryCount";

public static final String HTTP_FUNCTION_ALLOWED_DOMAINS = "AdaptiveAuth.HTTPFunctionAllowedDomains.Domain";
public static final String CHOREO_DOMAINS = "AdaptiveAuth.ChoreoDomains.Domain";
Expand Down