From 1abfb24811b338e44251dd8ab608dde250e40c48 Mon Sep 17 00:00:00 2001 From: Andy Kuny Date: Tue, 5 Mar 2024 09:34:40 -0500 Subject: [PATCH 1/2] Add migration and seed scripts --- ...remove_username_column_from_users_table.py | 29 ++++++++++ nad_ch/application/use_cases/auth.py | 1 - nad_ch/controllers/web/templates/index.html | 2 +- nad_ch/domain/entities.py | 5 +- nad_ch/infrastructure/auth.py | 7 +++ nad_ch/infrastructure/database.py | 3 -- pyproject.toml | 3 ++ scripts/format.py | 2 +- scripts/migrate_down.py | 45 ++++++++++++++++ scripts/migrate_up.py | 19 +++++++ scripts/seed.py | 54 +++++++++++++++++++ tests/application/use_cases/test_auth.py | 5 +- 12 files changed, 162 insertions(+), 13 deletions(-) create mode 100644 alembic/versions/dc3dd97eae46_remove_username_column_from_users_table.py create mode 100644 scripts/migrate_down.py create mode 100644 scripts/migrate_up.py create mode 100644 scripts/seed.py diff --git a/alembic/versions/dc3dd97eae46_remove_username_column_from_users_table.py b/alembic/versions/dc3dd97eae46_remove_username_column_from_users_table.py new file mode 100644 index 0000000..27e61dd --- /dev/null +++ b/alembic/versions/dc3dd97eae46_remove_username_column_from_users_table.py @@ -0,0 +1,29 @@ +"""Remove username column from users table + +Revision ID: dc3dd97eae46 +Revises: 945ca77479d1 +Create Date: 2024-03-05 13:21:29.812837 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "dc3dd97eae46" +down_revision: Union[str, None] = "945ca77479d1" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade(): + op.drop_column("users", "username") + + +def downgrade(): + op.add_column( + "users", + sa.Column("username", sa.String, nullable=True), + ) diff --git a/nad_ch/application/use_cases/auth.py b/nad_ch/application/use_cases/auth.py index cc140c0..8f4da18 100644 --- a/nad_ch/application/use_cases/auth.py +++ b/nad_ch/application/use_cases/auth.py @@ -16,7 +16,6 @@ def get_or_create_user(ctx: ApplicationContext, provider_name: str, email: str) new_user = User( email=email, - username=email.split("@")[0], login_provider=provider_name, logout_url=ctx.auth.get_logout_url(provider_name), ) diff --git a/nad_ch/controllers/web/templates/index.html b/nad_ch/controllers/web/templates/index.html index 592df03..460f032 100644 --- a/nad_ch/controllers/web/templates/index.html +++ b/nad_ch/controllers/web/templates/index.html @@ -13,7 +13,7 @@
{% if current_user.is_authenticated %} -

Hi, {{ current_user.username }}!

+

Hi, {{ current_user.email }}!

Thanks for logging in with {{ current_user.login_provider }}.

{% else %}
    diff --git a/nad_ch/domain/entities.py b/nad_ch/domain/entities.py index 3759746..ba708c1 100644 --- a/nad_ch/domain/entities.py +++ b/nad_ch/domain/entities.py @@ -63,9 +63,8 @@ def has_report(self) -> bool: class User(Entity): - def __init__(self, username, email, login_provider, logout_url, id: int = None): + def __init__(self, email, login_provider, logout_url, id: int = None): super().__init__(id) - self.username = username self.email = email self.login_provider = login_provider self.logout_url = logout_url @@ -91,4 +90,4 @@ def get_id(self) -> str: raise NotImplementedError("No `id` attribute - override `get_id`") from None def __repr__(self): - return f"User {self.id}, {self.username}, {self.email})" + return f"User {self.id}, {self.email})" diff --git a/nad_ch/infrastructure/auth.py b/nad_ch/infrastructure/auth.py index 9e14462..f28ae32 100644 --- a/nad_ch/infrastructure/auth.py +++ b/nad_ch/infrastructure/auth.py @@ -69,6 +69,13 @@ def fetch_user_email_from_login_provider( return None + def get_logout_url(self, provider_name: str) -> str | None: + provider_config = self._providers[provider_name] + if not provider_config: + return None + + return provider_config["logout_url"] + def make_login_url(self, provider_name: str, state_token: str) -> str | None: provider_config = self._providers[provider_name] if not provider_config: diff --git a/nad_ch/infrastructure/database.py b/nad_ch/infrastructure/database.py index 858dd00..94119ba 100644 --- a/nad_ch/infrastructure/database.py +++ b/nad_ch/infrastructure/database.py @@ -113,7 +113,6 @@ def to_entity(self, producer: DataProducer): class UserModel(UserMixin, CommonBase): __tablename__ = "users" - username = Column(String) email = Column(String) login_provider = Column(String) logout_url = Column(String) @@ -123,7 +122,6 @@ def from_entity(user): model = UserModel( id=user.id, email=user.email, - username=user.username, login_provider=user.login_provider, logout_url=user.logout_url, ) @@ -133,7 +131,6 @@ def to_entity(self): entity = User( id=self.id, email=self.email, - username=self.username, login_provider=self.login_provider, logout_url=self.logout_url, ) diff --git a/pyproject.toml b/pyproject.toml index f526e2b..90b231a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,9 @@ pytest-env = "^1.1.3" cli = "nad_ch.main:run_cli" format = "scripts.format:main" lint = "flake8.main.cli:main" +migrate_down = "scripts.migrate_down:main" +migrate_up = "scripts.migrate_up:main" +seed = "scripts.seed:main" start-web = "nad_ch.main:serve_flask_app" test = "pytest:main" diff --git a/scripts/format.py b/scripts/format.py index e688851..4a9d0ae 100644 --- a/scripts/format.py +++ b/scripts/format.py @@ -2,7 +2,7 @@ def main(): - subprocess.run(["black", "./nad_ch", "./alembic", "./tests"]) + subprocess.run(["black", "./alembic", "./nad_ch", "./scripts", "./tests"]) if __name__ == "__main__": diff --git a/scripts/migrate_down.py b/scripts/migrate_down.py new file mode 100644 index 0000000..1319d3d --- /dev/null +++ b/scripts/migrate_down.py @@ -0,0 +1,45 @@ +import os +from alembic.config import Config +from alembic import command +from boto3.session import Session +from botocore.client import Config as BotocoreConfig +from nad_ch.config.development_local import ( + S3_ENDPOINT, + S3_ACCESS_KEY, + S3_SECRET_ACCESS_KEY, + S3_BUCKET_NAME, +) + + +def main(): + if os.getenv("APP_ENV") != "dev_local": + raise Exception("This script can only be run in a local dev environment.") + + current_script_path = os.path.abspath(__file__) + project_root = os.path.dirname(os.path.dirname(current_script_path)) + alembic_cfg_path = os.path.join(project_root, "alembic.ini") + + alembic_cfg = Config(alembic_cfg_path) + command.downgrade(alembic_cfg, "base") + + # flush storage + session = Session() + minio_client = session.client( + "s3", + endpoint_url=S3_ENDPOINT, + aws_access_key_id=S3_ACCESS_KEY, + aws_secret_access_key=S3_SECRET_ACCESS_KEY, + aws_session_token=None, + region_name="us-east-1", + verify=False, + config=BotocoreConfig(signature_version="s3v4"), + ) + response = minio_client.list_objects_v2(Bucket=S3_BUCKET_NAME) + + for object in response["Contents"]: + print("Deleting", object["Key"]) + minio_client.delete_object(Bucket=S3_BUCKET_NAME, Key=object["Key"]) + + +if __name__ == "__main__": + main() diff --git a/scripts/migrate_up.py b/scripts/migrate_up.py new file mode 100644 index 0000000..ea838d4 --- /dev/null +++ b/scripts/migrate_up.py @@ -0,0 +1,19 @@ +import os +from alembic.config import Config +from alembic import command + + +def main(): + if os.getenv("APP_ENV") != "dev_local": + raise Exception("This script can only be run in a local dev environment.") + + current_script_path = os.path.abspath(__file__) + project_root = os.path.dirname(os.path.dirname(current_script_path)) + alembic_cfg_path = os.path.join(project_root, "alembic.ini") + + alembic_cfg = Config(alembic_cfg_path) + command.upgrade(alembic_cfg, "head") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed.py b/scripts/seed.py new file mode 100644 index 0000000..608f458 --- /dev/null +++ b/scripts/seed.py @@ -0,0 +1,54 @@ +import os +import zipfile +from nad_ch.config import create_app_context, OAUTH2_CONFIG +from nad_ch.domain.entities import DataProducer, DataSubmission, User + + +def zip_directory(folder_path, zip_path): + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, dirs, files in os.walk(folder_path): + for file in files: + zipf.write( + os.path.join(root, file), + os.path.relpath( + os.path.join(root, file), os.path.join(folder_path, "..") + ), + ) + + +def main(): + if os.getenv("APP_ENV") != "dev_local": + raise Exception("This script can only be run in a local dev environment.") + + ctx = create_app_context() + + new_producer = DataProducer(name="New Jersey") + saved_producer = ctx.producers.add(new_producer) + + new_user = User( + email="test@test.org", + login_provider="cloudgov", + logout_url=OAUTH2_CONFIG["cloudgov"]["logout_url"], + ) + saved_user = ctx.users.add(new_user) + + current_script_path = os.path.abspath(__file__) + project_root = os.path.dirname(os.path.dirname(current_script_path)) + gdb_path = os.path.join( + project_root, "tests", "test_data", "geodatabases", "Naperville.gdb" + ) + zipped_gdb_path = os.path.join( + project_root, "tests", "test_data", "geodatabases", "Naperville.gdb.zip" + ) + zip_directory(gdb_path, zipped_gdb_path) + + filename = DataSubmission.generate_filename(zipped_gdb_path, saved_producer) + ctx.storage.upload(zipped_gdb_path, filename) + new_submission = DataSubmission(filename, saved_producer) + saved_submission = ctx.submissions.add(new_submission) + + os.remove(zipped_gdb_path) + + +if __name__ == "__main__": + main() diff --git a/tests/application/use_cases/test_auth.py b/tests/application/use_cases/test_auth.py index acdca3a..91e9dc9 100644 --- a/tests/application/use_cases/test_auth.py +++ b/tests/application/use_cases/test_auth.py @@ -21,9 +21,7 @@ def test_get_or_create_user_existing_user(app_context): app_context.auth.make_login_url = lambda x: "test" email = "johnny@test.org" login_provider = "test" - user = User( - username="johnny", email=email, login_provider=login_provider, logout_url="test" - ) + user = User(email=email, login_provider=login_provider, logout_url="test") app_context.users.add(user) result = get_or_create_user(app_context, login_provider, email) assert result == user @@ -35,7 +33,6 @@ def test_get_or_create_user_new_user(app_context): result = get_or_create_user(app_context, login_provider, email) assert isinstance(result, User) assert result.email == email - assert result.username == "johnny" assert result.login_provider == login_provider From a755f6b1b764132a37c3ce19860aaeb7c8c03f27 Mon Sep 17 00:00:00 2001 From: Andy Kuny Date: Tue, 5 Mar 2024 09:39:19 -0500 Subject: [PATCH 2/2] Address linting errors --- scripts/seed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/seed.py b/scripts/seed.py index 608f458..bbd95c4 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -30,7 +30,7 @@ def main(): login_provider="cloudgov", logout_url=OAUTH2_CONFIG["cloudgov"]["logout_url"], ) - saved_user = ctx.users.add(new_user) + ctx.users.add(new_user) current_script_path = os.path.abspath(__file__) project_root = os.path.dirname(os.path.dirname(current_script_path)) @@ -45,7 +45,7 @@ def main(): filename = DataSubmission.generate_filename(zipped_gdb_path, saved_producer) ctx.storage.upload(zipped_gdb_path, filename) new_submission = DataSubmission(filename, saved_producer) - saved_submission = ctx.submissions.add(new_submission) + ctx.submissions.add(new_submission) os.remove(zipped_gdb_path)