diff --git a/src/koheesio/spark/transformations/download_files.py b/src/koheesio/spark/transformations/download_files.py index 296571c2..3e134f22 100644 --- a/src/koheesio/spark/transformations/download_files.py +++ b/src/koheesio/spark/transformations/download_files.py @@ -146,7 +146,7 @@ def execute(self) -> Output: Download files from URLs in the specified column. """ # Collect the URLs from the DataFrame and process them - source_column_name = get_column_name(self.column) # type: ignore + source_column_name = self.column if isinstance(self.column, str) else get_column_name(self.column) # type: ignore partition = {row.asDict()[source_column_name] for row in self.df.select(self.column).collect()} # type: ignore self.func(partition) diff --git a/tests/spark/transformations/test_download_files.py b/tests/spark/transformations/test_download_files.py index 00f2c2dd..71345334 100644 --- a/tests/spark/transformations/test_download_files.py +++ b/tests/spark/transformations/test_download_files.py @@ -1,10 +1,13 @@ +from pathlib import Path + import pytest -from koheesio.spark.transformations.download_files import DownloadFileFromUrlTransformation +from koheesio.spark import DataFrame, SparkSession # type: ignore +from koheesio.spark.transformations.download_files import DownloadFileFromUrlTransformation # type: ignore @pytest.fixture -def input_df(spark): +def input_df(spark: SparkSession) -> DataFrame: """A simple DataFrame containing two URLs.""" return spark.createDataFrame( [ @@ -16,7 +19,7 @@ def input_df(spark): @pytest.fixture -def download_path(tmp_path): +def download_path(tmp_path: Path) -> Path: _path = tmp_path / "downloads" _path.mkdir(exist_ok=True) return _path @@ -24,18 +27,41 @@ def download_path(tmp_path): class TestDownloadFileFromUrlTransformation: """ + Input DataFrame: + | key | url | |-----|--------------------------------------------| | 101 | http://www.textfiles.com/100/adventur.txt | | 102 | http://www.textfiles.com/100/arttext.fun | + + Output DataFrame: + + | key | url | downloaded_file_path | + |-----|--------------------------------------------|-----------------------| + | 101 | http://www.textfiles.com/100/adventur.txt | downloads/adventur.txt| + | 102 | http://www.textfiles.com/100/arttext.fun | downloads/arttext.fun | + """ - def test_downloading_files(self, input_df, download_path): + def test_downloading_files(self, input_df: DataFrame, download_path: Path) -> None: + """Test that the files are downloaded and the DataFrame is transformed correctly.""" + # Arrange + expected_data = [ + "downloads/adventur.txt", + "downloads/arttext.fun", + ] + + # Act transformed_df = DownloadFileFromUrlTransformation( column="url", download_path=download_path, target_column="downloaded_file_path", ).transform(input_df) + actual_data = sorted( + [row.asDict()["downloaded_file_path"] for row in transformed_df.select("downloaded_file_path").collect()] + ) + + # Assert # Check that adventur.txt and arttext.fun are actually downloaded assert (download_path / "adventur.txt").exists() @@ -43,12 +69,6 @@ def test_downloading_files(self, input_df, download_path): assert transformed_df.count() == 2 assert transformed_df.columns == ["key", "url", "downloaded_file_path"] + # check that the rows of the output DataFrame are as expected - expected_data = [ - "downloads/adventur.txt", - "downloads/arttext.fun", - ] - actual_data = sorted( - [row.asDict()["downloaded_file_path"] for row in transformed_df.select("downloaded_file_path").collect()] - ) assert actual_data == expected_data