diff --git a/pyproject.toml b/pyproject.toml index 8493bfe3..610bd9af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -467,6 +467,7 @@ exclude = [ [tool.ruff.format] # https://docs.astral.sh/ruff/formatter/#docstring-formatting docstring-code-format = true +docstring-code-line-length = 70 [tool.ruff.lint] select = [ diff --git a/src/koheesio/spark/transformations/__init__.py b/src/koheesio/spark/transformations/__init__.py index 3f273a85..ca9866b6 100644 --- a/src/koheesio/spark/transformations/__init__.py +++ b/src/koheesio/spark/transformations/__init__.py @@ -49,19 +49,25 @@ class Transformation(SparkStep, ABC): Example ------- + ### Implementing a transformation using the Transformation class: ```python from koheesio.steps.transformations import Transformation from pyspark.sql import functions as f class AddOne(Transformation): + target_column: str = "new_column" + def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + self.target_column, f.col("old_column") + 1 + ) ``` In the example above, the `execute` method is implemented to add 1 to the values of the `old_column` and store the result in a new column called `new_column`. + ### Using the transformation: In order to use this transformation, we can call the `transform` method: ```python @@ -85,6 +91,7 @@ def execute(self): | 2| 3| ... + ### Alternative ways to use the transformation: Alternatively, we can pass the DataFrame to the constructor and call the `execute` or `transform` method without any arguments: @@ -94,9 +101,24 @@ def execute(self): output_df = AddOne(df).execute().output.df ``` - Note: that the transform method was not implemented explicitly in the AddOne class. This is because the `transform` + > Note: that the transform method was not implemented explicitly in the AddOne class. This is because the `transform` method is already implemented in the `Transformation` class. 
This means that all classes that inherit from the Transformation class will have the `transform` method available. Only the execute method needs to be implemented. + + ### Using the transformation as a function: + The transformation can also be used as a function as part of a DataFrame's `transform` method: + + ```python + input_df = spark.range(3) + + output_df = input_df.transform(AddOne(target_column="foo")).transform( + AddOne(target_column="bar") + ) + ``` + + In the above example, the `AddOne` transformation is applied to the `input_df` DataFrame using the `transform` + method. The `output_df` will now contain the original DataFrame with additional columns called `foo` and + `bar`, each with the values of `id` + 1. """ df: Optional[DataFrame] = Field(default=None, description="The Spark DataFrame") @@ -111,7 +133,9 @@ def execute(self) -> SparkStep.Output: For example: ```python def execute(self): - self.output.df = self.df.withColumn("new_column", f.col("old_column") + 1) + self.output.df = self.df.withColumn( + "new_column", f.col("old_column") + 1 + ) ``` The transform method will call this method and return the output DataFrame. @@ -147,6 +171,26 @@ def transform(self, df: Optional[DataFrame] = None) -> DataFrame: self.execute() return self.output.df + def __call__(self, *args, **kwargs): + """Allow the class to be called as a function. + This is especially useful when using a DataFrame's transform method. + + Example + ------- + ```python + input_df = spark.range(3) + + output_df = input_df.transform(AddOne(target_column="foo")).transform( + AddOne(target_column="bar") + ) + ``` + + In the above example, the `AddOne` transformation is applied to the `input_df` DataFrame using the `transform` + method. The `output_df` will now contain the original DataFrame with additional columns called `foo` and + `bar`, each with the values of `id` + 1. 
+ """ + return self.transform(*args, **kwargs) + class ColumnsTransformation(Transformation, ABC): """Extended Transformation class with a preset validator for handling column(s) data with a standardized input @@ -204,7 +248,9 @@ class ColumnsTransformation(Transformation, ABC): class AddOne(ColumnsTransformation): def execute(self): for column in self.get_columns(): - self.output.df = self.df.withColumn(column, f.col(column) + 1) + self.output.df = self.df.withColumn( + column, f.col(column) + 1 + ) ``` In the above example, the `execute` method is implemented to add 1 to the values of a given column. @@ -460,7 +506,9 @@ class ColumnsTransformationWithTarget(ColumnsTransformation, ABC): ```python from pyspark.sql import Column - from koheesio.steps.transformations import ColumnsTransformationWithTarget + from koheesio.steps.transformations import ( + ColumnsTransformationWithTarget, + ) class AddOneWithTarget(ColumnsTransformationWithTarget): @@ -477,7 +525,9 @@ def func(self, col: Column): # create a DataFrame with 3 rows df = SparkSession.builder.getOrCreate().range(3) - output_df = AddOneWithTarget(column="id", target_column="new_id").transform(df) + output_df = AddOneWithTarget( + column="id", target_column="new_id" + ).transform(df) ``` The `output_df` will now contain the original DataFrame with an additional column called `new_id` with the values of diff --git a/tests/spark/transformations/test_transform.py b/tests/spark/transformations/test_transform.py index 1f92e490..ce46a3d5 100644 --- a/tests/spark/transformations/test_transform.py +++ b/tests/spark/transformations/test_transform.py @@ -6,7 +6,9 @@ from koheesio.logger import LoggingFactory from koheesio.spark import DataFrame +from koheesio.spark.transformations.strings.substring import Substring from koheesio.spark.transformations.transform import Transform +from koheesio.spark.transformations.hash import Sha2Hash pytestmark = pytest.mark.spark @@ -83,3 +85,29 @@ def test_from_func(dummy_df): 
AddFooColumn = Transform.from_func(dummy_transform_func, target_column="foo") df = AddFooColumn(value="bar").transform(dummy_df) assert transform_output_test(df, {"id": 0, "foo": "bar"}) + + +def test_df_transform_compatibility(dummy_df: DataFrame): + + expected_data = { + "id": 0, + "foo": "bar", + "bar": "baz", + "foo_hash": "fcde2b2edba56bf408601fb721fe9b5c338d10ee429ea04fae5511b68fbf8fb9", + "foo_sub": "fcde2b", + } + + # set up a reusable Transform from a function + add_column = Transform.from_func(dummy_transform_func, value="bar") + + output_df = ( + dummy_df + # test the Transform class with multiple chained transforms + .transform(add_column(target_column="foo")) + .transform(add_column(target_column="bar", value="baz")) + # test that Transformation classes can be called directly by DataFrame.transform + .transform(Sha2Hash(columns="foo", target_column="foo_hash")) + .transform(Substring(column="foo_hash", start=1, length=6, target_column="foo_sub")) + ) + + assert output_df.head().asDict() == expected_data