diff --git a/README.md b/README.md index 2d51fae..d0b250e 100644 --- a/README.md +++ b/README.md @@ -358,6 +358,7 @@ Os seguintes Status de Pedidos estão disponíveis: | `CONCLUIDO` | O Pedido foi concluído | | `CANCELADO` | O Pedido foi cancelado | + Para executar o projeto localmente utilizando Docker, siga as seguintes etapas: 1. Crie a infraestrutura kubernetes utilizando @@ -389,3 +390,4 @@ aws eks update-kubeconfig --region us-east-1 --name ambrosia-serve-cluster ```shell kubectl get svc ``` + diff --git a/docs/db_model.png b/docs/db_model.png new file mode 100644 index 0000000..5f2af04 Binary files /dev/null and b/docs/db_model.png differ diff --git a/src/cart/adapters/AuditMixin.py b/src/cart/adapters/AuditMixin.py new file mode 100644 index 0000000..70a0931 --- /dev/null +++ b/src/cart/adapters/AuditMixin.py @@ -0,0 +1,13 @@ +from datetime import datetime + +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.sql.functions import now + + +class AuditMixin: + created_at: Mapped[datetime] = mapped_column(default=now()) + updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + + def __init__(self, created_at: datetime, updated_at: datetime): + self.created_at = created_at + self.updated_at = updated_at diff --git a/src/cart/adapters/order_table.py b/src/cart/adapters/order_table.py index 02d64cd..ac40c82 100644 --- a/src/cart/adapters/order_table.py +++ b/src/cart/adapters/order_table.py @@ -1,25 +1,75 @@ -from datetime import datetime -from sqlalchemy import String, ForeignKey +from sqlalchemy import ForeignKey from sqlalchemy.orm import relationship, Mapped, mapped_column, DeclarativeBase -from sqlalchemy.sql.functions import now + +from src.cart.adapters.AuditMixin import AuditMixin class Base(DeclarativeBase): pass -class OrderTable(Base): +class StatusTable(Base, AuditMixin): + __tablename__ = "status" + + id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) + status: Mapped[str] = mapped_column(nullable=False, unique=True) + + orders: Mapped[list["OrderTable"]] = relationship(back_populates="status", + cascade="all, delete", + passive_deletes=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.id = kwargs.get('id') + self.status = kwargs.get('status') + self.created_at = kwargs.get('created_at') + self.updated_at = kwargs.get('updated_at') + + def __repr__(self): + return (f"StatusTable(id={self.id}, " + f"status={self.status}, " + f"created_at={self.created_at}, " + f"updated_at={self.updated_at})") + + +class PaymentConditionTable(Base, AuditMixin): + __tablename__ = "payment_conditions" + + id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) + description: Mapped[str] + + order: Mapped["OrderTable"] = relationship(back_populates="payment_condition", + cascade="all, delete", + passive_deletes=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.id = kwargs.get('id') + self.description = kwargs.get('description') + self.created_at = kwargs.get('created_at') + self.updated_at = kwargs.get('updated_at') + + def __repr__(self): + return (f"PaymentConditionTable(id={self.id}, " + f"description={self.description}, " + f"created_at={self.created_at}, " + f"updated_at={self.updated_at})") + + +class OrderTable(Base, AuditMixin): __tablename__ = "orders" id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False, index=True) user_id: Mapped[str] - status: Mapped[int] - payment_condition: Mapped[str] + status_id: Mapped[str] = mapped_column(ForeignKey("status.id", ondelete="CASCADE")) + status: Mapped["StatusTable"] = relationship(back_populates="orders") + payment_condition_id: Mapped[str] = mapped_column(ForeignKey("payment_conditions.id", ondelete="CASCADE")) + payment_condition: Mapped["PaymentConditionTable"] = relationship(back_populates="order") total: Mapped[float] - products: Mapped["OrderProductTable"] = relationship("OrderProductTable", back_populates="order") - created_at: Mapped[datetime] = mapped_column(default=now()) - updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + products: Mapped[list["OrderProductTable"]] = relationship(back_populates="order", + cascade="all, delete", + passive_deletes=True) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -39,18 +89,16 @@ def __repr__(self): f"updated_at={self.updated_at})") -class OrderProductTable(Base): +class OrderProductTable(Base, AuditMixin): __tablename__ = "order_products" - id: Mapped[str] = mapped_column(String(255), primary_key=True, nullable=False, autoincrement=False) + id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) product_id: Mapped[str] quantity: Mapped[int] observation: Mapped[str] - created_at: Mapped[datetime] = mapped_column(default=now()) - updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) - order_id: Mapped[str] = mapped_column(ForeignKey("orders.id")) - order = relationship("OrderTable", back_populates="products") + order_id: Mapped[str] = mapped_column(ForeignKey("orders.id", ondelete="CASCADE")) + order: Mapped["OrderTable"] = relationship(back_populates="products") def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/cart/adapters/postgres_gateway.py b/src/cart/adapters/postgres_gateway.py index ae52b6b..27281cc 100644 --- a/src/cart/adapters/postgres_gateway.py +++ b/src/cart/adapters/postgres_gateway.py @@ -3,6 +3,7 @@ from sqlalchemy import Row from src.cart.domain.entities.order import Order +from src.cart.domain.enums.order_status import OrderStatus from src.cart.domain.enums.paymentConditions import PaymentConditions from src.cart.ports.cart_gateway import ICartGateway from src.cart.ports.unit_of_work_interface import ICartUnitOfWork @@ -27,7 +28,8 @@ def get_order_products(self, order_id) -> list[dict]: def get_orders(self) -> list[Order]: with self.uow: orders: Optional[list[Row]] = self.uow.repository.get_all() - return [self.build_order_entity(o) for o in orders] + orders_entity = [self.build_order_entity(o) for o in orders] + return orders_entity def get_order_by_id(self, order_id: str) -> Optional[Order]: with self.uow: @@ -42,7 +44,7 @@ def create_update_order(self, order: Order) -> Order: self.uow.repository.insert_update({ 'id': order.id, 'user_id': order.user, - 'status': order.order_status.value, + 'status': order.order_status.name, 'payment_condition': condition.name, 'total': order.total_order, 'products': [{'id': p.id, @@ -68,5 +70,5 @@ def build_order_entity(order): return Order(_id=order.id, user=order.user_id, order_datetime=order.created_at, - order_status=order.status, + order_status=OrderStatus[order.status], payment_condition=payment_condition.value) diff --git a/src/cart/adapters/postgresql_repository.py b/src/cart/adapters/postgresql_repository.py index dbb5a49..7437548 100644 --- a/src/cart/adapters/postgresql_repository.py +++ b/src/cart/adapters/postgresql_repository.py @@ -1,12 +1,21 @@ +import uuid from typing import Any, Sequence, Optional from sqlalchemy import select, delete, Row from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from src.cart.adapters.order_table import OrderTable, OrderProductTable +from src.cart.adapters.order_table import OrderTable, OrderProductTable, StatusTable, PaymentConditionTable from src.cart.ports.repository_interface import IRepository +ORDER_COLS: tuple = ( + OrderTable.id, + OrderTable.user_id, + StatusTable.status, + PaymentConditionTable.description.label('payment_condition'), + OrderTable.created_at, +) + class PostgreSqlRepository(IRepository): @@ -17,12 +26,10 @@ def __init__(self, session: Session): def get_all(self) -> Optional[Sequence[Row]]: stmt = ( select( - OrderTable.id.label('order_id'), - OrderTable.user_id, - OrderTable.status, - OrderTable.payment_condition, - OrderTable.created_at + *ORDER_COLS ) + .join(StatusTable, OrderTable.status_id == StatusTable.id) + .join(PaymentConditionTable, OrderTable.payment_condition_id == PaymentConditionTable.id) .order_by(OrderTable.status, OrderTable.created_at) ) @@ -36,13 +43,11 @@ def get_all(self) -> Optional[Sequence[Row]]: def filter_by_id(self, order_id: str) -> Optional[Row]: stmt = ( select( - OrderTable.id, - OrderTable.user_id, - OrderTable.status, - OrderTable.payment_condition, - OrderTable.created_at, - OrderTable.updated_at + *ORDER_COLS ) + .join(StatusTable, OrderTable.status_id == StatusTable.id) + .join(PaymentConditionTable, OrderTable.payment_condition_id == PaymentConditionTable.id) + .order_by(OrderTable.status, OrderTable.created_at) .where(OrderTable.id == order_id) ) @@ -71,6 +76,14 @@ def _upsert_order(self, values: dict[str, Any]): """Inserção ou atualização de um pedido.""" order_data = {key: values[key] for key in values if key != "products"} + status = self.create_or_get_status(values) + payment_condition = self.create_or_get_payment_condition(values) + + order_data["status_id"] = status.id + order_data["payment_condition_id"] = payment_condition.id + + del order_data["status"] + del order_data["payment_condition"] stmt = insert(OrderTable).values(order_data) stmt = stmt.on_conflict_do_update( index_elements=[OrderTable.id], @@ -78,6 +91,23 @@ def _upsert_order(self, values: dict[str, Any]): ) self.session.execute(stmt) + def create_or_get_payment_condition(self, values): + payment_condition = (self.session.query(PaymentConditionTable) + .filter_by(description=values["payment_condition"]).first()) + if not payment_condition: + payment_condition = PaymentConditionTable(id=str(uuid.uuid4()), description=values["payment_condition"]) + self.session.add(payment_condition) + self.session.commit() + return payment_condition + + def create_or_get_status(self, values): + status = self.session.query(StatusTable).filter_by(status=values["status"]).first() + if not status: + status = StatusTable(id=str(uuid.uuid4()), status=values["status"]) + self.session.add(status) + self.session.commit() + return status + def _upsert_order_products(self, order_id: str, products: list[dict[str, Any]]): """Inserção ou atualização dos produtos relacionados a um pedido.""" for product in products: @@ -91,6 +121,8 @@ def _upsert_order_products(self, order_id: str, products: list[dict[str, Any]]): ) self.session.execute(stmt) + self.session.commit() + def delete(self, order_id: str): stmt = delete(OrderTable).where(OrderTable.id == order_id) self.session.execute(stmt) diff --git a/src/cart/adapters/pydantic_presenter.py b/src/cart/adapters/pydantic_presenter.py index cc4c68c..f15a7f1 100644 --- a/src/cart/adapters/pydantic_presenter.py +++ b/src/cart/adapters/pydantic_presenter.py @@ -1,7 +1,6 @@ from src.api.presentation.shared.dtos.order_response_dto import OrderResponseDto, OrderProductResponseDto from src.cart.domain.entities.order import Order -from src.cart.domain.enums.order_status import OrderStatus from src.cart.ports.cart_presenter import ICartPresenter @@ -20,7 +19,7 @@ def formater(order): return OrderResponseDto(id=order.id, user=order.user, total_order=order.total_order, - order_status=OrderStatus(order.order_status), + order_status=order.order_status.name, payment_condition=order.payment_condition, products=[OrderProductResponseDto( product=p.product.id, diff --git a/src/cart/domain/entities/order.py b/src/cart/domain/entities/order.py index 7c9722b..0f10c62 100644 --- a/src/cart/domain/entities/order.py +++ b/src/cart/domain/entities/order.py @@ -30,6 +30,9 @@ def __hash__(self): def __eq__(self, other): return self.id == other.id + def __gt__(self, other): + return self.order_status.value > other.order_status.value + @property def id(self): return self._id diff --git a/src/cart/use_cases/create_cart.py b/src/cart/use_cases/create_cart.py index 697acad..280c438 100644 --- a/src/cart/use_cases/create_cart.py +++ b/src/cart/use_cases/create_cart.py @@ -6,7 +6,7 @@ from src.cart.domain.enums.order_status import OrderStatus from src.cart.exceptions import (ClientError, ProductNotFoundError, - OrderExistsError) + OrderExistsError, OrderNotFoundError) from src.cart.ports.cart_gateway import ICartGateway from src.cart.use_cases.get_order_by_id import GetOrderByIdUseCase from src.client.ports.user_gateway import IUserGateway @@ -36,9 +36,11 @@ def execute(self, request_data: Dict): for product in products_required: order.add_product(product) - - if self.get_order_by_id.execute(order.id): - raise OrderExistsError(order=order.id) + try: + if self.get_order_by_id.execute(order.id): + raise OrderExistsError(order=order.id) + except OrderNotFoundError: + pass return self.cart_gateway.create_update_order(order) diff --git a/src/cart/use_cases/update_order_status.py b/src/cart/use_cases/update_order_status.py index a90eff2..13c8ff6 100644 --- a/src/cart/use_cases/update_order_status.py +++ b/src/cart/use_cases/update_order_status.py @@ -1,3 +1,4 @@ +from src.cart.domain.enums.order_status import OrderStatus from src.cart.domain.validators.order_validator import OrderValidator from src.cart.exceptions import OrderNotFoundError from src.cart.ports.cart_gateway import ICartGateway @@ -13,6 +14,6 @@ def execute(self, order_id: str, new_status: str): order = self.get_order_by_id.execute(order_id) if not order: raise OrderNotFoundError(order=order_id) - order.order_status = new_status + order.order_status = OrderStatus[new_status] OrderValidator.validate(order) return self.gateway.create_update_order(order) diff --git a/src/client/adapters/AuditMixin.py b/src/client/adapters/AuditMixin.py new file mode 100644 index 0000000..70a0931 --- /dev/null +++ b/src/client/adapters/AuditMixin.py @@ -0,0 +1,13 @@ +from datetime import datetime + +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.sql.functions import now + + +class AuditMixin: + created_at: Mapped[datetime] = mapped_column(default=now()) + updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + + def __init__(self, created_at: datetime, updated_at: datetime): + self.created_at = created_at + self.updated_at = updated_at diff --git a/src/client/adapters/client_table.py b/src/client/adapters/client_table.py index 0b1b70b..081db21 100644 --- a/src/client/adapters/client_table.py +++ b/src/client/adapters/client_table.py @@ -1,24 +1,50 @@ -from datetime import datetime +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship -from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase -from sqlalchemy.sql.functions import now +from src.client.adapters.AuditMixin import AuditMixin class Base(DeclarativeBase): pass -class ClientTable(Base): - __tablename__ = "clients" +class CredentialTable(Base, AuditMixin): + __tablename__ = "credentials" + + id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) + email: Mapped[str] + password: Mapped[str] + + profile: Mapped["ProfileTable"] = relationship(back_populates="credential", + cascade="all, delete", + passive_deletes=True) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.id = kwargs.get('id') + self.email = kwargs.get('email') + self.password = kwargs.get('password') + self.created_at = kwargs.get('created_at') + self.updated_at = kwargs.get('updated_at') + + def __repr__(self): + return (f"CredentialTable(id={self.id}, " + f"email={self.email}, " + f"password={self.password}, " + f"created_at={self.created_at}, " + f"updated_at={self.updated_at})") + + +class ProfileTable(Base, AuditMixin): + __tablename__ = "profiles" id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) first_name: Mapped[str] last_name: Mapped[str] cpf: Mapped[str] - email: Mapped[str] - password: Mapped[str] - created_at: Mapped[datetime] = mapped_column(default=now()) - updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + + credential_id: Mapped[str] = mapped_column(ForeignKey("credentials.id", ondelete="CASCADE")) + credential: Mapped["CredentialTable"] = relationship(back_populates="profile") def __init__(self, **kwargs): super().__init__(**kwargs) @@ -26,16 +52,13 @@ def __init__(self, **kwargs): self.first_name = kwargs.get('first_name') self.last_name = kwargs.get('last_name') self.cpf = kwargs.get('cpf') - self.email = kwargs.get('email') - self.password = kwargs.get('password') self.created_at = kwargs.get('created_at') self.updated_at = kwargs.get('updated_at') def __repr__(self): - return (f"ClientTable(id={self.id}, " + return (f"ProfileTable(id={self.id}, " f"first_name={self.first_name}, " f"last_name={self.last_name}, " f"cpf={self.cpf}, " - f"email={self.email}, " f"created_at={self.created_at}, " f"updated_at={self.updated_at})") diff --git a/src/client/adapters/postgresql_repository.py b/src/client/adapters/postgresql_repository.py index 0d3fb3c..490af93 100644 --- a/src/client/adapters/postgresql_repository.py +++ b/src/client/adapters/postgresql_repository.py @@ -1,12 +1,22 @@ +import uuid from typing import Dict, Sequence, Optional from sqlalchemy import select, Row, delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from src.client.adapters.client_table import ClientTable +from src.client.adapters.client_table import ProfileTable, CredentialTable from src.client.ports.repository_interface import IClientRepository +PROFILE_COLS: tuple = ( + ProfileTable.id, + ProfileTable.first_name, + ProfileTable.last_name, + ProfileTable.cpf, + CredentialTable.email, + CredentialTable.password, +) + class PostgreSqlClientRepository(IClientRepository): def __init__(self, session: Session): @@ -14,12 +24,7 @@ def __init__(self, session: Session): self.session = session def get_users(self) -> Optional[Sequence[Row[tuple]]]: - stmt = select(ClientTable.id, - ClientTable.first_name, - ClientTable.last_name, - ClientTable.cpf, - ClientTable.email, - ClientTable.password) + stmt = select(*PROFILE_COLS).join(ProfileTable.credential) results = self.session.execute(stmt).all() if not results: return [] @@ -27,12 +32,7 @@ def get_users(self) -> Optional[Sequence[Row[tuple]]]: return results def get_user_by_cpf(self, cpf: str) -> Optional[Row]: - stmt = select(ClientTable.id, - ClientTable.first_name, - ClientTable.last_name, - ClientTable.cpf, - ClientTable.email, - ClientTable.password).where(ClientTable.cpf == cpf) + stmt = select(*PROFILE_COLS).join(ProfileTable.credential).where(ProfileTable.cpf == cpf) results = self.session.execute(stmt).first() if not results: return @@ -40,12 +40,7 @@ def get_user_by_cpf(self, cpf: str) -> Optional[Row]: return results def get_user_by_email(self, email: str) -> Optional[Row]: - stmt = select(ClientTable.id, - ClientTable.first_name, - ClientTable.last_name, - ClientTable.cpf, - ClientTable.email, - ClientTable.password).where(ClientTable.email == email) + stmt = select(*PROFILE_COLS).join(ProfileTable.credential).where(CredentialTable.email == email) results = self.session.execute(stmt).first() if not results: return @@ -53,28 +48,50 @@ def get_user_by_email(self, email: str) -> Optional[Row]: return results def create_user(self, user: Dict): - stmt = insert(ClientTable).values(**user) - self.session.execute(stmt) + credential = self.create_or_get_credential(user_id=user['id'], email=user['email'], password=user['password']) + del user['email'] + del user['password'] + profile = ProfileTable(credential_id=credential.id, **user) + self.session.add(profile) + self.session.commit() + + def create_or_get_credential(self, user_id: str, email: str, password: str): + credential = self.session.query(CredentialTable).join(CredentialTable.profile).filter( + ProfileTable.id == user_id).first() + if not credential: + credential = CredentialTable(id=str(uuid.uuid4()), email=email, password=password) + self.session.add(credential) + self.session.commit() + return credential def update_user(self, user: Dict): - stmt = insert(ClientTable).values(**user) + credential_data = self.create_or_get_credential(user_id=user['id'], email=user['email'], + password=user['password']) + if credential_data: + stmt = insert(CredentialTable).values(id=credential_data.id, email=user['email'], password=user['password']) + stmt = stmt.on_conflict_do_update( + index_elements=[CredentialTable.id], + set_=dict(email=user['email'], password=user['password']) + ) + self.session.execute(stmt) + + self.session.commit() + + profile_data = {key: user[key] for key in user if key not in ['email', 'password']} + stmt = insert(ProfileTable).values(credential_id=credential_data.id, **profile_data) stmt = stmt.on_conflict_do_update( - index_elements=[ClientTable.id], - set_={key: user[key] for key in user if key != 'id'} + index_elements=[ProfileTable.id], + set_=profile_data ) self.session.execute(stmt) def delete_user(self, user_id: str): - stmt = delete(ClientTable).where(ClientTable.id == user_id) + stmt = delete(ProfileTable).where(ProfileTable.id == user_id) self.session.execute(stmt) + self.session.commit() def get_user_by_id(self, user_id: str) -> Optional[Row]: - stmt = select(ClientTable.id, - ClientTable.first_name, - ClientTable.last_name, - ClientTable.cpf, - ClientTable.email, - ClientTable.password).where(ClientTable.id == user_id) + stmt = select(*PROFILE_COLS).join(ProfileTable.credential).where(ProfileTable.id == user_id) results = self.session.execute(stmt).first() if not results: return diff --git a/src/product/adapters/AuditMixin.py b/src/product/adapters/AuditMixin.py new file mode 100644 index 0000000..70a0931 --- /dev/null +++ b/src/product/adapters/AuditMixin.py @@ -0,0 +1,13 @@ +from datetime import datetime + +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.sql.functions import now + + +class AuditMixin: + created_at: Mapped[datetime] = mapped_column(default=now()) + updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + + def __init__(self, created_at: datetime, updated_at: datetime): + self.created_at = created_at + self.updated_at = updated_at diff --git a/src/product/adapters/postgres_gateway.py b/src/product/adapters/postgres_gateway.py index d902cd5..4efd829 100644 --- a/src/product/adapters/postgres_gateway.py +++ b/src/product/adapters/postgres_gateway.py @@ -16,7 +16,7 @@ def __init__(self, uow: IProductUnitOfWork): def get_products(self) -> list[Product]: with self.uow: products = self.uow.repository.get_all() - return [self.build_product_entity(p[0]) for p in products] + return [self.build_product_entity(p) for p in products] def get_product_by_id(self, product_id: str) -> Optional[Product]: with self.uow: @@ -48,7 +48,7 @@ def create_update_product(self, product: Product) -> Product: 'name': product.name, 'category': product.category, 'description': product.description, - 'image': product.image, + 'image': product.image, 'stock': product.stock, 'price': product.price}) self.uow.commit() diff --git a/src/product/adapters/postgresql_repository.py b/src/product/adapters/postgresql_repository.py index 918ff10..31c7f3e 100644 --- a/src/product/adapters/postgresql_repository.py +++ b/src/product/adapters/postgresql_repository.py @@ -1,20 +1,32 @@ +import uuid from typing import Any, Sequence, Optional from sqlalchemy import select, Row, delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from src.product.adapters.product_table import ProductTable +from src.product.adapters.product_table import ProductTable, CategoryTable from src.product.ports.repository_interface import IProductRepository +PRODUCTS_COLS: tuple = ( + ProductTable.id, + ProductTable.name, + ProductTable.description, + CategoryTable.category, + ProductTable.image, + ProductTable.price, + ProductTable.stock, +) + class PostgreSqlProductRepository(IProductRepository): + def __init__(self, session: Session): super().__init__() self.session = session def get_all(self) -> Sequence[Row]: - stmt = select(ProductTable) + stmt = select(*PRODUCTS_COLS).join(ProductTable.category) results = self.session.execute(stmt).all() if not results: return [] @@ -22,30 +34,47 @@ def get_all(self) -> Sequence[Row]: return results def find_by_name(self, name: str) -> Optional[Row]: - stmt = select(ProductTable).where(ProductTable.name == name) + stmt = select(*PRODUCTS_COLS).join(ProductTable.category).where(ProductTable.name == name) result = self.session.execute(stmt).first() if not result: return - return result[0] + return result def filter_by_id(self, product_id: str) -> Optional[Row]: - stmt = select(ProductTable).where(ProductTable.id == product_id) + stmt = select(*PRODUCTS_COLS).join(ProductTable.category).where(ProductTable.id == product_id) result = self.session.execute(stmt).first() if not result: return - return result[0] + return result def insert_update(self, values: dict[str, Any]): + category = self.create_or_get_category(values['category']) + + values['category_id'] = category.id + del values['category'] stmt = insert(ProductTable).values(**values) stmt = stmt.on_conflict_do_update( index_elements=[ProductTable.id], set_={key: values[key] for key in values if key != 'id'}, ) self.session.execute(stmt) + self.session.commit() + + def create_or_get_category(self, category: str): + _category = self.session.query(CategoryTable).filter_by(category=category).first() + if not _category: + _category = CategoryTable( + id=str(uuid.uuid4()), + category=category + ) + self.session.add(_category) + self.session.commit() + return _category def delete(self, product_id: str): stmt = delete(ProductTable).where(ProductTable.id == product_id) self.session.execute(stmt) + self.session.commit() diff --git a/src/product/adapters/product_table.py b/src/product/adapters/product_table.py index 266277b..9d13e90 100644 --- a/src/product/adapters/product_table.py +++ b/src/product/adapters/product_table.py @@ -1,15 +1,40 @@ -from datetime import datetime from typing import Any, Optional -from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase -from sqlalchemy.sql.functions import now +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship + +from src.product.adapters.AuditMixin import AuditMixin class Base(DeclarativeBase): pass -class ProductTable(Base): +class CategoryTable(Base, AuditMixin): + __tablename__ = "categories" + + id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) + category: Mapped[str] = mapped_column(nullable=False, unique=True) + + product: Mapped[list["ProductTable"]] = relationship(back_populates="category", + cascade="all, delete", + passive_deletes=True) + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.id = kwargs.get('id') + self.category = kwargs.get('category') + self.created_at = kwargs.get('created_at') + self.updated_at = kwargs.get('updated_at') + + def __repr__(self): + return (f"CategoryTable(id={self.id!r}, " + f"category={self.category!r}, " + f"created_at={self.created_at!r}, " + f"updated_at={self.updated_at!r})") + + +class ProductTable(Base, AuditMixin): __tablename__ = "products" id: Mapped[str] = mapped_column(primary_key=True, nullable=False, autoincrement=False) @@ -17,10 +42,10 @@ class ProductTable(Base): description: Mapped[str] price: Mapped[float] stock: Mapped[int] - category: Mapped[str] image: Mapped[Optional[str]] - created_at: Mapped[datetime] = mapped_column(default=now()) - updated_at: Mapped[datetime] = mapped_column(default=now(), onupdate=now()) + + category_id: Mapped[str] = mapped_column(ForeignKey("categories.id", ondelete="CASCADE")) + category: Mapped["CategoryTable"] = relationship(back_populates="product") def __init__(self, **kwargs: Any): super().__init__(**kwargs) @@ -31,6 +56,7 @@ def __init__(self, **kwargs: Any): self.stock = kwargs.get('stock') self.category = kwargs.get('category') self.image = kwargs.get('image') + self.category_id = kwargs.get('category_id') self.created_at = kwargs.get('created_at') self.updated_at = kwargs.get('updated_at') @@ -42,5 +68,6 @@ def __repr__(self): f"stock={self.stock!r}, " f"category={self.category!r}, " f"image={self.image!r}, " + f"category_id={self.category_id!r}, " f"created_at={self.created_at!r}, " f"updated_at={self.updated_at!r})")