Skip to content

Commit

Permalink
chore: Add param support for client_request_mfa_token (apache#40394)
Browse files Browse the repository at this point in the history
* chore: Add param support for client_reqquest_mfa_token

* chore: Update test

* chore: Fix test param beat extra

* chore: Update get_params for client_request_mfa_token

* chore: Update client_request_mfa_token to be added if it is set to be True

* chore: Add client_request_mfa_token back to the test

* chore: Update the test as boolean for cient_request_mfa_token

* chore: Cleanup

* chore: Exclude client request mfa token as part of uri

* style: Lint
  • Loading branch information
vanducng authored Jun 25, 2024
1 parent 19baf9c commit c310159
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
10 changes: 9 additions & 1 deletion airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"authenticator": "snowflake oauth",
"private_key_file": "private key",
"session_parameters": "session parameters",
"client_request_mfa_token": "client request mfa token",
},
indent=1,
),
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(self, *args, **kwargs) -> None:
self.schema = kwargs.pop("schema", None)
self.authenticator = kwargs.pop("authenticator", None)
self.session_parameters = kwargs.pop("session_parameters", None)
self.client_request_mfa_token = kwargs.pop("client_request_mfa_token", None)
self.query_ids: list[str] = []

def _get_field(self, extra_dict, field_name):
Expand Down Expand Up @@ -194,6 +196,7 @@ def _get_conn_params(self) -> dict[str, str | None]:
role = self._get_field(extra_dict, "role") or ""
insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode"))
schema = conn.schema or ""
client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token"))

# authenticator and session_parameters never supported long name so we don't use _get_field
authenticator = extra_dict.get("authenticator", "snowflake")
Expand All @@ -216,6 +219,9 @@ def _get_conn_params(self) -> dict[str, str | None]:
if insecure_mode:
conn_config["insecure_mode"] = insecure_mode

if client_request_mfa_token:
conn_config["client_request_mfa_token"] = client_request_mfa_token

# If private_key_file is specified in the extra json, load the contents of the file as a private key.
# If private_key_content is specified in the extra json, use it as a private key.
# As a next step, specify this private key in the connection configuration.
Expand Down Expand Up @@ -280,7 +286,9 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str:
**{
k: v
for k, v in conn_params.items()
if v and k not in ["session_parameters", "insecure_mode", "private_key"]
if v
and k
not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"]
}
)

Expand Down
4 changes: 4 additions & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "True",
"extra__snowflake__client_request_mfa_token": "True",
},
},
(
Expand All @@ -156,6 +157,7 @@ class TestPytestSnowflakeHook:
"user": "user",
"warehouse": "af_wh",
"insecure_mode": True,
"client_request_mfa_token": True,
},
),
(
Expand All @@ -168,6 +170,7 @@ class TestPytestSnowflakeHook:
"extra__snowflake__region": "af_region",
"extra__snowflake__role": "af_role",
"extra__snowflake__insecure_mode": "False",
"extra__snowflake__client_request_mfa_token": "False",
},
},
(
Expand Down Expand Up @@ -243,6 +246,7 @@ class TestPytestSnowflakeHook:
"extra": {
**BASE_CONNECTION_KWARGS["extra"],
"extra__snowflake__insecure_mode": False,
"extra__snowflake__client_request_mfa_token": False,
},
},
(
Expand Down

0 comments on commit c310159

Please sign in to comment.