diff --git a/nad_ch/domain/entities.py b/nad_ch/domain/entities.py index 95d86ad..48f77a6 100644 --- a/nad_ch/domain/entities.py +++ b/nad_ch/domain/entities.py @@ -1,3 +1,8 @@ +import datetime +import os +import re + + class Entity: def __init__(self, id: int = None): self.id = id @@ -24,17 +29,26 @@ def __repr__(self): class DataSubmission(Entity): def __init__( self, - file_name: str, - url: str, + filename: str, provider: DataProvider, id: int = None, ): super().__init__(id) - self.file_name = file_name - self.url = url + self.filename = filename self.provider = provider def __repr__(self): return f"DataSubmission \ - {self.id}, {self.file_name}, {self.url}, {self.provider} \ + {self.id}, {self.filename}, {self.provider} \ (created: {self.created_at}; updated: {self.updated_at})" + + @staticmethod + def generate_filename(file_path: str, provider: DataProvider) -> str: + s = re.sub(r'\W+', '_', provider.name) + s = s.lower() + s = s.strip('_') + formatted_provider_name = re.sub(r'_+', '_', s) + datetime_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + _, file_extension = os.path.splitext(file_path) + filename = f"{formatted_provider_name}_{datetime_str}{file_extension}" + return filename diff --git a/nad_ch/infrastructure/database.py b/nad_ch/infrastructure/database.py index badb342..07d5cfc 100644 --- a/nad_ch/infrastructure/database.py +++ b/nad_ch/infrastructure/database.py @@ -67,8 +67,7 @@ def to_entity(self): class DataSubmissionModel(CommonBase): __tablename__ = "data_submissions" - file_name = Column(String) - url = Column(String) + filename = Column(String) data_provider_id = Column(Integer, ForeignKey("data_providers.id")) data_provider = relationship("DataProviderModel", back_populates="data_submissions") @@ -77,15 +76,14 @@ class DataSubmissionModel(CommonBase): def from_entity(submission): model = DataSubmissionModel( id=submission.id, - file_name=submission.file_name, - url=submission.url, + filename=submission.filename, data_provider_id=submission.provider.id, ) return model def to_entity(self, provider: DataProvider): entity = DataSubmission( - id=self.id, file_name=self.file_name, url=self.url, provider=provider + id=self.id, filename=self.filename, provider=provider ) if self.created_at is not None: @@ -142,7 +140,7 @@ def add(self, submission: DataSubmission) -> DataSubmission: ) return submission_model.to_entity(provider_model.to_entity()) - def get_by_name(self, file_name: str) -> Optional[DataSubmission]: + def get_by_id(self, id: int) -> Optional[DataSubmission]: with self.session_factory() as session: result = ( session.query(DataSubmissionModel, DataProviderModel) @@ -150,7 +148,7 @@ def get_by_name(self, file_name: str) -> Optional[DataSubmission]: DataProviderModel, DataProviderModel.id == DataSubmissionModel.data_provider_id, ) - .filter(DataSubmissionModel.file_name == file_name) + .filter(DataSubmissionModel.id == id) .first() ) diff --git a/nad_ch/infrastructure/storage.py b/nad_ch/infrastructure/storage.py index 4b83e68..8de9c07 100644 --- a/nad_ch/infrastructure/storage.py +++ b/nad_ch/infrastructure/storage.py @@ -17,5 +17,5 @@ def delete(self, file_path: str) -> None: if os.path.exists(full_file_path): os.remove(full_file_path) - def get_file_url(self, file_name: str) -> str: - return file_name + def get_file_url(self, filename: str) -> str: + return filename diff --git a/nad_ch/use_cases.py b/nad_ch/use_cases.py index d1e175e..56ff854 100644 --- a/nad_ch/use_cases.py +++ b/nad_ch/use_cases.py @@ -1,3 +1,4 @@ +import os from typing import List from nad_ch.application_context import ApplicationContext from nad_ch.domain.entities import DataProvider, DataSubmission @@ -34,16 +35,21 @@ def ingest_data_submission( ctx.logger.error("File path required") return + _, file_extension = os.path.splitext(file_path) + if file_extension.lower() not in ['.zip', '.csv']: + ctx.logger.error("Invalid file format. Only ZIP or CSV files are accepted.") + return + provider = ctx.providers.get_by_name(provider_name) if not provider: ctx.logger.error("Provider with that name does not exist") return try: - ctx.storage.upload(file_path, f"{provider.name}_{file_path}") - url = ctx.storage.get_file_url(file_path) + filename = DataSubmission.generate_filename(file_path, provider) + ctx.storage.upload(file_path, filename) - submission = DataSubmission(file_path, url, provider) + submission = DataSubmission(filename, provider) ctx.submissions.add(submission) ctx.logger.info("Submission added") except Exception as e: @@ -62,6 +68,6 @@ def list_data_submissions_by_provider( submissions = ctx.submissions.get_by_provider(provider) ctx.logger.info(f"Data submissions for {provider.name}") for s in submissions: - ctx.logger.info(f"{s.provider.name}: {s.file_name}") + ctx.logger.info(f"{s.provider.name}: {s.filename}") return submissions diff --git a/tests/domain/test_entities.py b/tests/domain/test_entities.py new file mode 100644 index 0000000..0e530c8 --- /dev/null +++ b/tests/domain/test_entities.py @@ -0,0 +1,12 @@ +import datetime +from nad_ch.domain.entities import DataProvider, DataSubmission + + +def test_data_submission_generates_filename(): + provider = DataProvider("Some Provider") + filename = DataSubmission.generate_filename("someupload.zip", provider) + todays_date = datetime.datetime.now().strftime("%Y%m%d") + print(filename) + assert filename.startswith("some_provider") + assert todays_date in filename + assert filename.endswith(".zip") diff --git a/tests/fakes.py b/tests/fakes.py index 2ec7ead..14352cf 100644 --- a/tests/fakes.py +++ b/tests/fakes.py @@ -32,8 +32,8 @@ def add(self, submission: DataSubmission) -> DataSubmission: self._next_id += 1 return submission - def get_by_name(self, file_name: str) -> Optional[DataSubmission]: - return next((s for s in self._submissions if s.file_name == file_name), None) + def get_by_id(self, id: int) -> Optional[DataSubmission]: + return next((s for s in self._submissions if s.id == id), None) def get_by_provider(self, provider: DataProvider) -> Optional[DataSubmission]: return [s for s in self._submissions if s.provider.name == provider.name] @@ -46,5 +46,5 @@ def __init__(self): def upload(self, source: str, destination: str) -> None: self._files.add(destination) - def get_file_url(self, file_name: str) -> str: - return file_name + def get_file_url(self, filename: str) -> str: + return filename diff --git a/tests/infrastructure/test_database.py b/tests/infrastructure/test_database.py index d4e1b00..975c2a8 100644 --- a/tests/infrastructure/test_database.py +++ b/tests/infrastructure/test_database.py @@ -60,7 +60,7 @@ def test_add_data_provider_and_then_data_submission(providers, submissions): provider_name = "State X" new_provider = DataProvider(provider_name) saved_provider = providers.add(new_provider) - new_submission = DataSubmission("some-file-name", "some-url", saved_provider) + new_submission = DataSubmission("some-file-name", saved_provider) result = submissions.add(new_submission) @@ -68,18 +68,17 @@ def test_add_data_provider_and_then_data_submission(providers, submissions): assert result.created_at is not None assert result.updated_at is not None assert result.provider.id == saved_provider.id - assert result.file_name == "some-file-name" - assert result.url == "some-url" + assert result.filename == "some-file-name" def test_retrieve_a_list_of_submissions_by_provider(providers, submissions): provider_name = "State X" new_provider = DataProvider(provider_name) saved_provider = providers.add(new_provider) - new_submission = DataSubmission("some-file-name", "some-url", saved_provider) + new_submission = DataSubmission("some-file-name", saved_provider) submissions.add(new_submission) another_new_submission = DataSubmission( - "some-other-file-name", "some-other-url", saved_provider + "some-other-file-name", saved_provider ) submissions.add(another_new_submission) diff --git a/tests/test_use_cases.py b/tests/test_use_cases.py index 1278005..4423019 100644 --- a/tests/test_use_cases.py +++ b/tests/test_use_cases.py @@ -64,11 +64,10 @@ def test_ingest_data_submission(app_context): provider_name = "State X" add_data_provider(app_context, provider_name) - file_name = "my_cool_file.txt" - ingest_data_submission(app_context, file_name, provider_name) + filename = "my_cool_file.zip" + ingest_data_submission(app_context, filename, provider_name) - submission = app_context.submissions.get_by_name(file_name) - assert submission.file_name == file_name + submission = app_context.submissions.get_by_id(1) assert isinstance(submission, DataSubmission) is True @@ -76,8 +75,8 @@ def test_list_data_submissions_by_provider(app_context): provider_name = "State X" add_data_provider(app_context, provider_name) - file_name = "my_cool_file.txt" - ingest_data_submission(app_context, file_name, provider_name) + filename = "my_cool_file.zip" + ingest_data_submission(app_context, filename, provider_name) provider = app_context.providers.get_by_name(provider_name) submissions = app_context.submissions.get_by_provider(provider)