Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-862760: Add additionalHeader support #1490

Merged
17 changes: 17 additions & 0 deletions src/main/java/net/snowflake/client/core/HttpUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -899,4 +899,21 @@ static int convertSystemPropertyToIntValue(String systemProperty, int defaultVal
}
return returnVal;
}

/**
* Helper function to attach additional headers to a request if present. This takes a (nullable)
* map of headers in <name,value> format and adds them to the incoming request using addHeader.
*
* <p>Snowsight uses this to attach headers with additional telemetry information, see
* https://snowflakecomputing.atlassian.net/wiki/spaces/EN/pages/2960557006/GS+Communication
*
* @param request The request to add headers to. Must not be null.
* @param additionalHeaders The headers to add. May be null.
*/
static void applyAdditionalHeadersForSnowsight(
HttpRequestBase request, Map<String, String> additionalHeaders) {
if (additionalHeaders != null && !additionalHeaders.isEmpty()) {
additionalHeaders.forEach(request::addHeader);
}
}
}
38 changes: 38 additions & 0 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public class SFLoginInput {
private HttpClientSettingsKey httpClientKey;
private String privateKeyFile;
private String privateKeyFilePwd;
private String inFlightCtx; // Opaque string sent for Snowsight account activation

// Additional headers to add for Snowsight.
Map<String, String> additionalHttpHeadersForSnowsight;

SFLoginInput() {}

Expand Down Expand Up @@ -338,6 +342,40 @@ SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
return this;
}

// Opaque string sent for Snowsight account activation
String getInFlightCtx() {
return inFlightCtx;
}

// Opaque string sent for Snowsight account activation
SFLoginInput setInFlightCtx(String inFlightCtx) {
this.inFlightCtx = inFlightCtx;
return this;
}

Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}

static boolean getBooleanValue(Object v) {
if (v instanceof Boolean) {
return (Boolean) v;
Expand Down
14 changes: 14 additions & 0 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ private static SFLoginOutput newSession(

try {
ClientAuthnDTO authnData = new ClientAuthnDTO();
authnData.setInFlightCtx(loginInput.getInFlightCtx());

Map<String, Object> data = new HashMap<>();
data.put(ClientAuthnParameter.CLIENT_APP_ID.name(), loginInput.getAppId());

Expand Down Expand Up @@ -585,6 +587,10 @@ private static SFLoginOutput newSession(

postRequest = new HttpPost(loginURI);

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
input.setContentType("application/json");
Expand Down Expand Up @@ -894,6 +900,10 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT
uriBuilder.addParameter(SFSession.SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());
} catch (URISyntaxException ex) {
logger.error("Exception when creating http request", ex);

Expand Down Expand Up @@ -1006,6 +1016,10 @@ static void closeSession(SFLoginInput loginInput) throws SFException, SnowflakeS

postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

postRequest.setHeader(
SF_HEADER_AUTHORIZATION,
SF_HEADER_SNOWFLAKE_AUTHTYPE
Expand Down
32 changes: 32 additions & 0 deletions src/main/java/net/snowflake/client/core/StmtUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ static class StmtInput {

QueryContextDTO queryContextDTO;

Map<String, String> additionalHttpHeadersForSnowsight;

StmtInput() {}

public StmtInput setSql(String sql) {
Expand Down Expand Up @@ -236,6 +238,26 @@ public StmtInput setMaxRetries(int maxRetries) {
this.maxRetries = maxRetries;
return this;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot
* be used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
@SuppressWarnings("unchecked")
public StmtInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}
}

/** Output for running a statement on server */
Expand Down Expand Up @@ -301,6 +323,10 @@ public static StmtOutput execute(StmtInput stmtInput, ExecTimeTelemetryData exec

httpRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* sequence id is only needed for old query API, when old query API
* is deprecated, we can remove sequence id.
Expand Down Expand Up @@ -590,6 +616,9 @@ protected static String getQueryResult(String getResultPath, StmtInput stmtInput
uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

httpRequest = new HttpGet(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

httpRequest.addHeader("accept", stmtInput.mediaType);

Expand Down Expand Up @@ -691,6 +720,9 @@ public static void cancel(StmtInput stmtInput) throws SFException, SnowflakeSQLE
uriBuilder.addParameter(SF_QUERY_REQUEST_ID, UUIDUtils.getUUID().toString());

httpRequest = new HttpPost(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* The JSON input has two fields: sqlText and requestId
Expand Down
136 changes: 136 additions & 0 deletions src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,31 @@

import static net.snowflake.client.TestUtil.systemGetEnv;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.UUID;
import net.snowflake.client.category.TestCategoryCore;
import net.snowflake.client.jdbc.BaseJDBCTest;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.common.core.ClientAuthnDTO;
import org.apache.commons.io.IOUtils;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpRequestBase;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.mockito.MockedStatic;
import org.mockito.MockedStatic.Verification;
import org.mockito.Mockito;

@Category(TestCategoryCore.class)
Expand Down Expand Up @@ -104,4 +115,129 @@ public void testConvertSystemPropertyToIntValue() {
assertEquals(
-1, HttpUtil.convertSystemPropertyToIntValue(HttpUtil.JDBC_TTL, HttpUtil.DEFAULT_TTL));
}

/**
* SNOW-862760 Tests that when additional headers are set on a login request, they are forwarded
* to the recipient.
*/
@Test
public void testForwardedHeaders() throws Throwable {
SFLoginInput input = createLoginInput();
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("Extra-Snowflake-Header", "present");

input.setAdditionalHttpHeadersForSnowsight(additionalHeaders);

Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
// Both mocks the call _and_ verifies that the headers are forwarded.
Verification httpCalledWithHeaders =
() ->
HttpUtil.executeGeneralRequest(
Mockito.argThat(
arg -> {
for (Entry<String, String> definedHeader : additionalHeaders.entrySet()) {
Header actualHeader = arg.getLastHeader(definedHeader.getKey());
if (actualHeader == null) {
return false;
}

if (!definedHeader.getValue().equals(actualHeader.getValue())) {
return false;
}
}

return true;
}),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class));
mockedHttpUtil
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

// After login, the only invocation to http should have been with the new
// headers.
// No calls should have happened without additional headers.
mockedHttpUtil.verify(times(1), httpCalledWithHeaders);
}
}

/**
* SNOW-862760 Verifies that, if inFlightCtx is provided to the login input, it's forwarded as
* part of the message body.
*/
@Test
public void testForwardInflightCtx() throws Throwable {
SFLoginInput input = createLoginInput();
String inflightCtx = UUID.randomUUID().toString();
input.setInFlightCtx(inflightCtx);

Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();

try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
// Both mocks the call _and_ verifies that the headers are forwarded.
Verification httpCalledWithHeaders =
() ->
HttpUtil.executeGeneralRequest(
Mockito.argThat(
arg -> {
try {
// This gets tricky because the entity is a string.
// To not fail on JSON parsing changes, we'll verify that the key
// inFlightCtx is present and the random UUID body
HttpEntity entity = ((HttpPost) arg).getEntity();
InputStream is = entity.getContent();
ByteArrayOutputStream out = new ByteArrayOutputStream();
IOUtils.copy(is, out);
String body = new String(out.toByteArray());
return body.contains("inFlightCtx") && body.contains(inflightCtx);
} catch (UnsupportedOperationException | IOException e) {
}
return false;
}),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.anyInt(),
Mockito.nullable(HttpClientSettingsKey.class));
mockedHttpUtil
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

// After login, the only invocation to http should have been with the new
// headers.
// No calls should have happened without additional headers.
mockedHttpUtil.verify(times(1), httpCalledWithHeaders);
}
}

private SFLoginInput createLoginInput() {
SFLoginInput input = new SFLoginInput();
input.setServerUrl("MOCK_TEST_HOST");
input.setUserName("MOCK_USERNAME");
input.setPassword("MOCK_PASSWORD");
input.setAccountName("MOCK_ACCOUNT_NAME");
input.setAppId("MOCK_APP_ID");
input.setOCSPMode(OCSPMode.FAIL_OPEN);
input.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN));
input.setLoginTimeout(1000);
input.setSessionParameters(new HashMap<>());

return input;
}
}
Loading