Skip to content

Commit

Permalink
Unroll the InputBase and rename headers
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vkatardjiev committed Aug 11, 2023
1 parent 75640a4 commit 29e9ff2
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/main/java/net/snowflake/client/core/HttpUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ static int convertSystemPropertyToIntValue(String systemProperty, int defaultVal
* @param request The request to add headers to. Must not be null.
* @param additionalHeaders The headers to add. May be null.
*/
static void applyAdditionalHeaders(
static void applyAdditionalHeadersForSnowsight(
HttpRequestBase request, Map<String, String> additionalHeaders) {
if (additionalHeaders != null && !additionalHeaders.isEmpty()) {
additionalHeaders.forEach(request::addHeader);
Expand Down
32 changes: 0 additions & 32 deletions src/main/java/net/snowflake/client/core/SFInputBase.java

This file was deleted.

32 changes: 30 additions & 2 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import net.snowflake.client.jdbc.ErrorCode;

/** A class for holding all information required for login */
public class SFLoginInput extends SFInputBase<SFLoginInput> {
public class SFLoginInput {
private static int DEFAULT_HTTP_CLIENT_CONNECTION_TIMEOUT = 60000; // millisec
private static int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT = 300000; // millisec

Expand Down Expand Up @@ -47,7 +47,10 @@ public class SFLoginInput extends SFInputBase<SFLoginInput> {
private HttpClientSettingsKey httpClientKey;
private String privateKeyFile;
private String privateKeyFilePwd;
private String inFlightCtx; // Used for Snowsight account activation
private String inFlightCtx; // Opaque string sent for Snowsight account activation

// Additional headers to add for Snowsight.
Map<String, String> additionalHttpHeadersForSnowsight;

SFLoginInput() {}

Expand Down Expand Up @@ -339,15 +342,40 @@ SFLoginInput setHttpClientSettingsKey(HttpClientSettingsKey key) {
return this;
}

// Opaque string sent for Snowsight account activation
String getInFlightCtx() {
return inFlightCtx;
}

// Opaque string sent for Snowsight account activation
SFLoginInput setInFlightCtx(String inFlightCtx) {
this.inFlightCtx = inFlightCtx;
return this;
}

Map<String, String> getAdditionalHttpHeadersForSnowsight() {
return additionalHttpHeadersForSnowsight;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}

static boolean getBooleanValue(Object v) {
if (v instanceof Boolean) {
return (Boolean) v;
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ private static SFLoginOutput newSession(
postRequest = new HttpPost(loginURI);

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(postRequest, loginInput.getAdditionalHttpHeaders());
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

// attach the login info json body to the post request
StringEntity input = new StringEntity(json, StandardCharsets.UTF_8);
Expand Down Expand Up @@ -901,7 +902,8 @@ private static SFLoginOutput tokenRequest(SFLoginInput loginInput, TokenRequestT
postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(postRequest, loginInput.getAdditionalHttpHeaders());
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());
} catch (URISyntaxException ex) {
logger.error("Exception when creating http request", ex);

Expand Down Expand Up @@ -1015,7 +1017,8 @@ static void closeSession(SFLoginInput loginInput) throws SFException, SnowflakeS
postRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(postRequest, loginInput.getAdditionalHttpHeaders());
HttpUtil.applyAdditionalHeadersForSnowsight(
postRequest, loginInput.getAdditionalHttpHeadersForSnowsight());

postRequest.setHeader(
SF_HEADER_AUTHORIZATION,
Expand Down
33 changes: 29 additions & 4 deletions src/main/java/net/snowflake/client/core/StmtUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public class StmtUtil {
static final SFLogger logger = SFLoggerFactory.getLogger(StmtUtil.class);

/** Input for executing a statement on server */
static class StmtInput extends SFInputBase<StmtInput> {
static class StmtInput {
String sql;

// default to snowflake (a special json format for snowflake query result
Expand Down Expand Up @@ -105,6 +105,8 @@ static class StmtInput extends SFInputBase<StmtInput> {

QueryContextDTO queryContextDTO;

Map<String, String> additionalHttpHeadersForSnowsight;

StmtInput() {}

public StmtInput setSql(String sql) {
Expand Down Expand Up @@ -236,6 +238,26 @@ public StmtInput setMaxRetries(int maxRetries) {
this.maxRetries = maxRetries;
return this;
}

/**
* Set additional http headers to apply to the outgoing request. The additional headers cannot
* be used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
*/
@SuppressWarnings("unchecked")
public StmtInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
this.additionalHttpHeadersForSnowsight = additionalHttpHeaders;
return this;
}
}

/** Output for running a statement on server */
Expand Down Expand Up @@ -302,7 +324,8 @@ public static StmtOutput execute(StmtInput stmtInput, ExecTimeTelemetryData exec
httpRequest = new HttpPost(uriBuilder.build());

// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(httpRequest, stmtInput.additionalHttpHeaders);
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* sequence id is only needed for old query API, when old query API
Expand Down Expand Up @@ -594,7 +617,8 @@ protected static String getQueryResult(String getResultPath, StmtInput stmtInput

httpRequest = new HttpGet(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(httpRequest, stmtInput.additionalHttpHeaders);
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

httpRequest.addHeader("accept", stmtInput.mediaType);

Expand Down Expand Up @@ -697,7 +721,8 @@ public static void cancel(StmtInput stmtInput) throws SFException, SnowflakeSQLE

httpRequest = new HttpPost(uriBuilder.build());
// Add custom headers before adding common headers
HttpUtil.applyAdditionalHeaders(httpRequest, stmtInput.additionalHttpHeaders);
HttpUtil.applyAdditionalHeadersForSnowsight(
httpRequest, stmtInput.additionalHttpHeadersForSnowsight);

/*
* The JSON input has two fields: sqlText and requestId
Expand Down
10 changes: 7 additions & 3 deletions src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public void testForwardedHeaders() throws Throwable {
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("Extra-Snowflake-Header", "present");

input.setAdditionalHttpHeaders(additionalHeaders);
input.setAdditionalHttpHeadersForSnowsight(additionalHeaders);

Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
Expand Down Expand Up @@ -158,7 +158,9 @@ public void testForwardedHeaders() throws Throwable {
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil.when(() -> HttpUtil.applyAdditionalHeaders(any(), any())).thenCallRealMethod();
mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

Expand Down Expand Up @@ -211,7 +213,9 @@ public void testForwardInflightCtx() throws Throwable {
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":null,\"message\":null,\"success\":true}");

mockedHttpUtil.when(() -> HttpUtil.applyAdditionalHeaders(any(), any())).thenCallRealMethod();
mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

SessionUtil.openSession(input, connectionPropertiesMap, "ALL");

Expand Down
8 changes: 5 additions & 3 deletions src/test/java/net/snowflake/client/core/StmtUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void testForwardedHeaders() throws Throwable {
SFLoginInput input = createLoginInput();
Map<String, String> additionalHeaders = new HashMap<>();
additionalHeaders.put("Extra-Snowflake-Header", "present");
input.setAdditionalHttpHeaders(additionalHeaders);
input.setAdditionalHttpHeadersForSnowsight(additionalHeaders);

try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
// Both mocks the call _and_ verifies that the headers are forwarded.
Expand Down Expand Up @@ -66,10 +66,12 @@ public void testForwardedHeaders() throws Throwable {
.when(httpCalledWithHeaders)
.thenReturn("{\"data\":null,\"code\":333334,\"message\":null,\"success\":true}");

mockedHttpUtil.when(() -> HttpUtil.applyAdditionalHeaders(any(), any())).thenCallRealMethod();
mockedHttpUtil
.when(() -> HttpUtil.applyAdditionalHeadersForSnowsight(any(), any()))
.thenCallRealMethod();

StmtInput stmtInput = new StmtInput();
stmtInput.setAdditionalHttpHeaders(additionalHeaders);
stmtInput.setAdditionalHttpHeadersForSnowsight(additionalHeaders);
// Async mode skips result post-processing so we don't need to mock an advanced
// response
stmtInput.setAsync(true);
Expand Down

0 comments on commit 29e9ff2

Please sign in to comment.