Skip to content

Commit

Permalink
refactor: improve type handling in DownloadFileFromUrlTransformation and enhance test type hints
Browse files (browse the repository at this point in the history)
  • Loading branch information
dannymeijer committed Dec 13, 2024
1 parent c18cf66 commit 3544a93
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/koheesio/spark/transformations/download_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def execute(self) -> Output:
Download files from URLs in the specified column.
"""
# Collect the URLs from the DataFrame and process them
source_column_name = get_column_name(self.column) # type: ignore
source_column_name = self.column if isinstance(self.column, str) else get_column_name(self.column) # type: ignore
partition = {row.asDict()[source_column_name] for row in self.df.select(self.column).collect()} # type: ignore
self.func(partition)

Expand Down
42 changes: 31 additions & 11 deletions tests/spark/transformations/test_download_files.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path

import pytest

from koheesio.spark.transformations.download_files import DownloadFileFromUrlTransformation
from koheesio.spark import DataFrame, SparkSession # type: ignore
from koheesio.spark.transformations.download_files import DownloadFileFromUrlTransformation # type: ignore


@pytest.fixture
def input_df(spark):
def input_df(spark: SparkSession) -> DataFrame:
"""A simple DataFrame containing two URLs."""
return spark.createDataFrame(
[
Expand All @@ -16,39 +19,56 @@ def input_df(spark):


@pytest.fixture
def download_path(tmp_path):
def download_path(tmp_path: Path) -> Path:
_path = tmp_path / "downloads"
_path.mkdir(exist_ok=True)
return _path


class TestDownloadFileFromUrlTransformation:
"""
Input DataFrame:
| key | url |
|-----|--------------------------------------------|
| 101 | http://www.textfiles.com/100/adventur.txt |
| 102 | http://www.textfiles.com/100/arttext.fun |
Output DataFrame:
| key | url | downloaded_file_path |
|-----|--------------------------------------------|-----------------------|
| 101 | http://www.textfiles.com/100/adventur.txt | downloads/adventur.txt|
| 102 | http://www.textfiles.com/100/arttext.fun | downloads/arttext.fun |
"""

def test_downloading_files(self, input_df, download_path):
def test_downloading_files(self, input_df: DataFrame, download_path: Path) -> None:
"""Test that the files are downloaded and the DataFrame is transformed correctly."""
# Arrange
expected_data = [
"downloads/adventur.txt",
"downloads/arttext.fun",
]

# Act
transformed_df = DownloadFileFromUrlTransformation(
column="url",
download_path=download_path,
target_column="downloaded_file_path",
).transform(input_df)
actual_data = sorted(
[row.asDict()["downloaded_file_path"] for row in transformed_df.select("downloaded_file_path").collect()]
)

# Assert

# Check that adventur.txt and arttext.fun are actually downloaded
assert (download_path / "adventur.txt").exists()
assert (download_path / "arttext.fun").exists()

assert transformed_df.count() == 2
assert transformed_df.columns == ["key", "url", "downloaded_file_path"]

# check that the rows of the output DataFrame are as expected
expected_data = [
"downloads/adventur.txt",
"downloads/arttext.fun",
]
actual_data = sorted(
[row.asDict()["downloaded_file_path"] for row in transformed_df.select("downloaded_file_path").collect()]
)
assert actual_data == expected_data

0 comments on commit 3544a93

Please sign in to comment.