diff --git a/alembic/versions/9bb8e29b98fa_user_dataproducer_relationship.py b/alembic/versions/9bb8e29b98fa_user_dataproducer_relationship.py new file mode 100644 index 0000000..d3c226f --- /dev/null +++ b/alembic/versions/9bb8e29b98fa_user_dataproducer_relationship.py @@ -0,0 +1,39 @@ +"""User DataProducer relationship + +Revision ID: 9bb8e29b98fa +Revises: 68982ccf2c7c +Create Date: 2024-03-27 14:21:24.094899 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "9bb8e29b98fa" +down_revision: Union[str, None] = "68982ccf2c7c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("activated", sa.Boolean, nullable=False, default=False), + ) + + op.add_column( + "users", + sa.Column( + "data_producer_id", + sa.Integer, + sa.ForeignKey("data_producers.id"), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("users", "activated") diff --git a/nad_ch/application/use_cases/column_maps.py b/nad_ch/application/use_cases/column_maps.py index 5f4983a..1a31732 100644 --- a/nad_ch/application/use_cases/column_maps.py +++ b/nad_ch/application/use_cases/column_maps.py @@ -17,7 +17,7 @@ def add_column_map( raise ValueError("User not found") # TODO get the producer name from the user's producer property - producer = ctx.producers.get_by_name("New Jersey") + producer = user.producer if producer is None: raise ValueError("Producer not found") diff --git a/nad_ch/controllers/web/routes/column_maps.py b/nad_ch/controllers/web/routes/column_maps.py index 676ffee..fc04da3 100644 --- a/nad_ch/controllers/web/routes/column_maps.py +++ b/nad_ch/controllers/web/routes/column_maps.py @@ -33,7 +33,7 @@ def before_request(): @login_required def index(): try: - view_models = get_column_maps_by_producer(g.ctx, "New Jersey") + view_models = get_column_maps_by_producer(g.ctx, current_user.producer.name) return render_template("column_maps/index.html", column_maps=view_models) except ValueError: abort(404) diff --git a/nad_ch/controllers/web/routes/data_submissions.py b/nad_ch/controllers/web/routes/data_submissions.py index be5eaa7..95a79b0 100644 --- a/nad_ch/controllers/web/routes/data_submissions.py +++ b/nad_ch/controllers/web/routes/data_submissions.py @@ -1,5 +1,5 @@ from flask import Blueprint, current_app, render_template, g -from flask_login import login_required +from flask_login import login_required, current_user from nad_ch.application.use_cases.data_submissions import ( get_data_submission, list_data_submissions_by_producer, @@ -30,7 +30,7 @@ def show(submission_id): @login_required def reports(): # For demo purposes, hard-code the producer name - view_model = list_data_submissions_by_producer(g.ctx, "New Jersey") + view_model = list_data_submissions_by_producer(g.ctx, current_user.producer.name) return render_template("data_submissions/index.html", submissions=view_model) diff --git a/nad_ch/core/entities.py b/nad_ch/core/entities.py index e02db10..688efcb 100644 --- a/nad_ch/core/entities.py +++ b/nad_ch/core/entities.py @@ -174,17 +174,27 @@ def has_report(self) -> bool: class User(Entity): - def __init__(self, email, login_provider, logout_url, id: int = None): + def __init__( + self, + email, + login_provider, + logout_url, + activated=False, + producer: DataProducer = None, + id: int = None, + ): super().__init__(id) self.email = email self.login_provider = login_provider self.logout_url = logout_url + self.activated = activated + self.producer = producer # The following property definitions and get_id method are required in order for the # Flask-Login library to be able to handle instances of the User domain entity. @property def is_active(self): - return True + return isinstance(self.producer, DataProducer) and self.activated @property def is_authenticated(self): @@ -201,4 +211,8 @@ def get_id(self) -> str: raise NotImplementedError("No `id` attribute - override `get_id`") from None def __repr__(self): - return f"User {self.id}, {self.email})" + return f"User {self.id}, {self.email}, {self.activated}, {self.producer.name})" + + def associate_with_data_producer(self, producer: DataProducer): + self.producer = producer + return self diff --git a/nad_ch/infrastructure/database.py b/nad_ch/infrastructure/database.py index 4e83710..400d80b 100644 --- a/nad_ch/infrastructure/database.py +++ b/nad_ch/infrastructure/database.py @@ -10,6 +10,7 @@ ColumnMapRepository, ) from sqlalchemy import ( + Boolean, Column, Integer, String, @@ -71,6 +72,7 @@ class DataProducerModel(CommonBase): ) column_maps = relationship("ColumnMapModel", back_populates="data_producer") + users = relationship("UserModel", back_populates="data_producer") @staticmethod def from_entity(producer: DataProducer): @@ -136,6 +138,10 @@ class UserModel(UserMixin, CommonBase): email = Column(String) login_provider = Column(String) logout_url = Column(String) + data_producer_id = Column(Integer, ForeignKey("data_producers.id"), nullable=True) + activated = Column(Boolean, nullable=False, default=False) + + data_producer = relationship("DataProducerModel", back_populates="users") @staticmethod def from_entity(user): @@ -144,15 +150,21 @@ def from_entity(user): email=user.email, login_provider=user.login_provider, logout_url=user.logout_url, + data_producer_id=user.producer.id if user.producer else None, + activated=user.activated, ) return model def to_entity(self): + producer = self.data_producer.to_entity() if self.data_producer else None + entity = User( id=self.id, email=self.email, login_provider=self.login_provider, logout_url=self.logout_url, + producer=producer, + activated=self.activated, ) if self.created_at is not None: diff --git a/scripts/seed.py b/scripts/seed.py index 16e3e31..4e47d05 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -29,6 +29,8 @@ def main(): email="test@test.org", login_provider="cloudgov", logout_url=OAUTH2_CONFIG["cloudgov"]["logout_url"], + producer=saved_producer, + activated=True, ) ctx.users.add(new_user) diff --git a/tests/application/use_cases/test_column_maps.py b/tests/application/use_cases/test_column_maps.py index 0852ee8..f073ae6 100644 --- a/tests/application/use_cases/test_column_maps.py +++ b/tests/application/use_cases/test_column_maps.py @@ -7,7 +7,7 @@ update_column_mapping_field, ) from nad_ch.application.view_models import ColumnMapViewModel -from nad_ch.core.entities import ColumnMap, DataProducer, User +from nad_ch.core.entities import DataProducer, User from nad_ch.config import create_app_context @@ -18,8 +18,8 @@ def app_context(): def test_add_column_map_is_valid(app_context): - app_context.producers.add(DataProducer("New Jersey")) - user = app_context.users.add(User("test@test.org", "foo", "bar")) + nj = app_context.producers.add(DataProducer("New Jersey")) + user = app_context.users.add(User("test@test.org", "foo", "bar", True, nj)) mapping = { "Add_Number": "address_number", @@ -55,8 +55,8 @@ def test_add_column_map_is_invalid(app_context): def test_get_column_map(app_context): - app_context.producers.add(DataProducer("New Jersey")) - user = app_context.users.add(User("test@test.org", "foo", "bar")) + nj = app_context.producers.add(DataProducer("New Jersey")) + user = app_context.users.add(User("test@test.org", "foo", "bar", True, nj)) mapping = { "Add_Number": "address_number", @@ -87,8 +87,8 @@ def test_get_column_map(app_context): def test_get_column_maps_by_producer(app_context): - app_context.producers.add(DataProducer("New Jersey")) - user = app_context.users.add(User("test@test.org", "foo", "bar")) + nj = app_context.producers.add(DataProducer("New Jersey")) + user = app_context.users.add(User("test@test.org", "foo", "bar", True, nj)) mapping = { "Add_Number": "address_number", @@ -121,8 +121,8 @@ def test_get_column_maps_by_producer(app_context): def test_update_column_mapping(app_context): - app_context.producers.add(DataProducer("New Jersey")) - user = app_context.users.add(User("test@test.org", "foo", "bar")) + nj = app_context.producers.add(DataProducer("New Jersey")) + user = app_context.users.add(User("test@test.org", "foo", "bar", True, nj)) mapping = { "Add_Number": "address_number", @@ -175,8 +175,8 @@ def test_update_column_mapping(app_context): def test_update_column_mapping_field(app_context): - app_context.producers.add(DataProducer("New Jersey")) - user = app_context.users.add(User("test@test.org", "foo", "bar")) + nj = app_context.producers.add(DataProducer("New Jersey")) + user = app_context.users.add(User("test@test.org", "foo", "bar", True, nj)) mapping = { "Add_Number": "address_number", diff --git a/tests/controllers/web/test_column_maps.py b/tests/controllers/web/test_column_maps.py index baf5699..9d25075 100644 --- a/tests/controllers/web/test_column_maps.py +++ b/tests/controllers/web/test_column_maps.py @@ -20,30 +20,33 @@ def client(app): @pytest.fixture def logged_in_client(client, app): with app.app_context(), app.test_request_context(): + producer = DataProducer("New Jersey") + saved_producer = app.extensions["ctx"]["producers"].add(producer) + user = User( - "test_user", "test_user@test.org", "test_provider", "test_logout_url" + "test_user@test.org", + "test_provider", + "test_logout_url", + True, + saved_producer, ) saved_user = app.extensions["ctx"]["users"].add(user) login_user(saved_user) - + print(saved_user) yield client logout_user() def test_column_maps_route_empty(logged_in_client): - logged_in_client.application.extensions["ctx"]["producers"].add( - DataProducer("New Jersey") - ) - response = logged_in_client.get("/column-maps") assert response.status_code == 200 assert "Create Your First Mapping" in response.data.decode("utf-8") def test_column_maps_route_has_two_column_maps(logged_in_client): - nj = logged_in_client.application.extensions["ctx"]["producers"].add( - DataProducer("New Jersey") + nj = logged_in_client.application.extensions["ctx"]["producers"].get_by_name( + "New Jersey" ) cm = ColumnMap( "Test", diff --git a/tests/controllers/web/test_data_submissions.py b/tests/controllers/web/test_data_submissions.py index 07b9ae0..7907a10 100644 --- a/tests/controllers/web/test_data_submissions.py +++ b/tests/controllers/web/test_data_submissions.py @@ -2,7 +2,7 @@ import pytest from nad_ch.config import create_app_context from nad_ch.controllers.web.flask import create_flask_application -from nad_ch.core.entities import User +from nad_ch.core.entities import DataProducer, User @pytest.fixture @@ -20,8 +20,15 @@ def client(app): @pytest.fixture def logged_in_client(client, app): with app.app_context(), app.test_request_context(): + producer = DataProducer("New Jersey") + saved_producer = app.extensions["ctx"]["producers"].add(producer) + user = User( - "test_user", "test_user@test.org", "test_provider", "test_logout_url" + "test_user@test.org", + "test_provider", + "test_logout_url", + True, + saved_producer, ) saved_user = app.extensions["ctx"]["users"].add(user) login_user(saved_user)