Skip to content

Commit

Permalink
fix auth flow, error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
zsimjee committed May 22, 2024
1 parent ae79ac9 commit 7dc0819
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
24 changes: 12 additions & 12 deletions guardrails/cli/configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,18 @@ def configure(
):
if clear_token is True:
token = DEFAULT_TOKEN
try:
save_configuration_file(token, enable_metrics)
logger.info("Configuration saved.")

if not token:
logger.info("No token provided. Skipping authentication.")
except Exception as e:
logger.error("An unexpected error occured!")
logger.error(e)
sys.exit(1)

# Authenticate with the Hub if token is not empty
# Authenticate with the Hub if token is not empty
if token != "" and token is not None:
logger.info("Validating credentials...")
try:
Expand All @@ -85,16 +95,6 @@ def configure(
"""
logger.log(level=LEVELS.get("SUCCESS", 25), msg=success_message)
except AuthenticationError as e:
logger.warn(e)
logger.error(e)
# We do not want to exit the program if the user fails to authenticate
# instead, save the token and other configuration options
try:
save_configuration_file(token, enable_metrics)
logger.info("Configuration saved.")

if not token:
logger.info("No token provided. Skipping authentication.")
except Exception as e:
logger.error("An unexpected error occured!")
logger.error(e)
sys.exit(1)
31 changes: 23 additions & 8 deletions guardrails/cli/server/hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from guardrails.cli.logger import logger
from guardrails.cli.server.module_manifest import ModuleManifest

TOKEN_EXPIRED_MESSAGE = (
"Your token has expired. Please run `guardrails configure` to update your token."
)
TOKEN_INVALID_MESSAGE = (
"Your token is invalid. Please run `guardrails configure` to update your token."
)
FIND_NEW_TOKEN = "You can find a new token at https://hub.guardrailsai.com/tokens"

TOKEN_EXPIRED_MESSAGE = f"""Your token has expired. Please run `guardrails configure`\
to update your token.
{FIND_NEW_TOKEN}"""
TOKEN_INVALID_MESSAGE = f"""Your token is invalid. Please run `guardrails configure`\
to update your token.
{FIND_NEW_TOKEN}"""

validator_hub_service = "https://so4sg4q4pb.execute-api.us-east-1.amazonaws.com"
validator_manifest_endpoint = Template(
Expand All @@ -27,6 +29,14 @@ class AuthenticationError(Exception):
pass


class ExpiredTokenError(Exception):
pass


class InvalidTokenError(Exception):
pass


class HttpError(Exception):
status: int
message: str
Expand Down Expand Up @@ -80,9 +90,9 @@ def get_jwt_token(creds: Credentials) -> Optional[str]:
except JWTDecodeError as e:
# if the error message includes "Expired", then the token is expired
if "Expired" in str(e):
raise Exception(TOKEN_EXPIRED_MESSAGE)
raise ExpiredTokenError(TOKEN_EXPIRED_MESSAGE)
else:
raise Exception(TOKEN_INVALID_MESSAGE)
raise InvalidTokenError(TOKEN_INVALID_MESSAGE)
return token


Expand All @@ -105,6 +115,9 @@ def get_validator_manifest(module_name: str):
except HttpError:
logger.error(f"Failed to install hub://{module_name}")
sys.exit(1)
except (ExpiredTokenError, InvalidTokenError) as e:
logger.error(AuthenticationError(e))
sys.exit(1)
except Exception as e:
logger.error("An unexpected error occurred!", e)
sys.exit(1)
Expand All @@ -122,6 +135,8 @@ def get_auth():
except HttpError as http_error:
logger.error(http_error)
raise AuthenticationError("Failed to authenticate!")
except (ExpiredTokenError, InvalidTokenError) as e:
raise AuthenticationError(e)
except Exception as e:
logger.error("An unexpected error occurred!", e)
raise AuthenticationError("Failed to authenticate!")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/cli/test_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_configure(mocker, token, no_metrics):
configure(token, no_metrics)

assert mock_logger_info.call_count == 2
expected_calls = [call("Validating credentials..."), call("Configuration saved.")]
expected_calls = [call("Configuration saved."), call("Validating credentials...")]
mock_logger_info.assert_has_calls(expected_calls)

mock_save_configuration_file.assert_called_once_with(token, no_metrics)
Expand Down

0 comments on commit 7dc0819

Please sign in to comment.