From 1377115027e60332f989b70f4f876b2f34a9c29a Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Mon, 3 Jun 2024 17:29:54 +0200 Subject: [PATCH 01/17] Ability to read Excel files - introduces pandas module - abstracted Reader class to a common BaseReader (as part of models module) - re-implemented spark.readers.Reader based on BaseReader - introduces ExcelReader at koheesio.pandas.readers.excel.ExcelReader - introduces ExcelReader at koheesio.spark.reader.excel.ExcelReader - added unittests to cover the above - added excel extra dependency - added docs --- pyproject.toml | 8 ++- .../spark/dq/spark_expectations.py | 8 +-- src/koheesio/models/reader.py | 50 ++++++++++++++++++ src/koheesio/pandas/__init__.py | 26 +++++++++ src/koheesio/pandas/readers/__init__.py | 34 ++++++++++++ src/koheesio/pandas/readers/excel.py | 50 ++++++++++++++++++ src/koheesio/spark/readers/__init__.py | 24 ++------- src/koheesio/spark/readers/excel.py | 40 ++++++++++++++ src/koheesio/utils.py | 1 - tests/_data/readers/excel_file/dummy.xlsx | Bin 0 -> 10314 bytes tests/conftest.py | 18 +++++++ tests/pandas/readers/test_pandas_excel.py | 28 ++++++++++ tests/spark/conftest.py | 22 +++----- .../dq/test_spark_expectations.py | 6 ++- tests/spark/readers/test_spark_excel.py | 27 ++++++++++ tests/spark/test_delta.py | 5 +- .../transformations/test_sql_transform.py | 2 +- .../spark/writers/delta/test_delta_writer.py | 7 ++- 18 files changed, 306 insertions(+), 50 deletions(-) create mode 100644 src/koheesio/models/reader.py create mode 100644 src/koheesio/pandas/__init__.py create mode 100644 src/koheesio/pandas/readers/__init__.py create mode 100644 src/koheesio/pandas/readers/excel.py create mode 100644 src/koheesio/spark/readers/excel.py create mode 100644 tests/_data/readers/excel_file/dummy.xlsx create mode 100644 tests/pandas/readers/test_pandas_excel.py create mode 100644 tests/spark/readers/test_spark_excel.py diff --git a/pyproject.toml b/pyproject.toml index 6abf886b..3c14c130 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ se = ["spark-expectations>=2.1.0"] # SFTP dependencies in to_csv line_iterator sftp = ["paramiko>=2.6.0"] delta = ["delta-spark>=2.2"] +excel = ["openpyxl>=3.0.0"] dev = ["black", "isort", "ruff", "mypy", "pylint", "colorama", "types-PyYAML"] test = [ "chispa", @@ -175,6 +176,7 @@ features = [ "pyspark", "sftp", "delta", + "excel", "se", "box", "dev", @@ -192,8 +194,8 @@ isort-check = "isort . --check --diff --color" isort-fmt = "isort ." ruff-check = "ruff check ." ruff-fmt = "ruff check . --fix" -mypy-check = "mypy koheesio" -pylint-check = "pylint --output-format=colorized -d W0511 koheesio" +mypy-check = "mypy src" +pylint-check = "pylint --output-format=colorized -d W0511 src" check = [ "- black-check", "- isort-check", @@ -249,6 +251,7 @@ features = [ "pyspark", "sftp", "delta", + "excel", "dev", "test", ] @@ -400,6 +403,7 @@ features = [ "se", "sftp", "delta", + "excel", "dev", "test", "docs", diff --git a/src/koheesio/integrations/spark/dq/spark_expectations.py b/src/koheesio/integrations/spark/dq/spark_expectations.py index 325ccaf5..8766a8e4 100644 --- a/src/koheesio/integrations/spark/dq/spark_expectations.py +++ b/src/koheesio/integrations/spark/dq/spark_expectations.py @@ -4,15 +4,17 @@ from typing import Any, Dict, Optional, Union -import pyspark -from pydantic import Field -from pyspark.sql import DataFrame from spark_expectations.config.user_config import Constants as user_config from spark_expectations.core.expectations import ( SparkExpectations, WrappedDataFrameWriter, ) +from pydantic import Field + +import pyspark +from pyspark.sql import DataFrame + from koheesio.spark.transformations import Transformation from koheesio.spark.writers import BatchOutputMode diff --git a/src/koheesio/models/reader.py b/src/koheesio/models/reader.py new file mode 100644 index 00000000..c1227940 --- /dev/null +++ b/src/koheesio/models/reader.py @@ -0,0 +1,50 @@ +""" +Module for the BaseReader class +""" + +from typing import Optional, TypeVar +from abc import ABC, abstractmethod + +from koheesio import Step + +# Define a type variable that can be any type of DataFrame +DataFrameType = TypeVar("DataFrameType") + + +class BaseReader(Step, ABC): + """Base class for all Readers + + A Reader is a Step that reads data from a source based on the input parameters + and stores the result in self.output.df (DataFrame). + + When implementing a Reader, the execute() method should be implemented. + The execute() method should read from the source and store the result in self.output.df. + + The Reader class implements a standard read() method that calls the execute() method and returns the result. This + method can be used to read data from a Reader without having to call the execute() method directly. Read method + does not need to be implemented in the child class. + + The Reader class also implements a shorthand for accessing the output Dataframe through the df-property. If the + output.df is None, .execute() will be run first. + """ + + @property + def df(self) -> Optional[DataFrameType]: + """Shorthand for accessing self.output.df + If the output.df is None, .execute() will be run first + """ + if not self.output.df: + self.execute() + return self.output.df + + @abstractmethod + def execute(self) -> Step.Output: + """Execute on a Reader should handle self.output.df (output) as a minimum + Read from whichever source -> store result in self.output.df + """ + pass + + def read(self) -> DataFrameType: + """Read from a Reader without having to call the execute() method directly""" + self.execute() + return self.output.df diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py new file mode 100644 index 00000000..b8494479 --- /dev/null +++ b/src/koheesio/pandas/__init__.py @@ -0,0 +1,26 @@ +"""Base class for a Pandas step + +Extends the Step class with Pandas DataFrame support. The following: +- Pandas steps are expected to return a Pandas DataFrame as output. +""" + +from typing import Optional +from abc import ABC + +from pandas import DataFrame + +from koheesio import Step, StepOutput +from koheesio.models import Field + + +class PandasStep(Step, ABC): + """Base class for a Pandas step + + Extends the Step class with Pandas DataFrame support. The following: + - Pandas steps are expected to return a Pandas DataFrame as output. + """ + + class Output(StepOutput): + """Output class for PandasStep""" + + df: Optional[DataFrame] = Field(default=None, description="The Pandas DataFrame") diff --git a/src/koheesio/pandas/readers/__init__.py b/src/koheesio/pandas/readers/__init__.py new file mode 100644 index 00000000..933561ae --- /dev/null +++ b/src/koheesio/pandas/readers/__init__.py @@ -0,0 +1,34 @@ +""" +Base class for all Readers +""" + +from abc import ABC, abstractmethod + +from koheesio.models.reader import BaseReader +from koheesio.pandas import PandasStep + + +class Reader(BaseReader, PandasStep, ABC): + """Base class for all Readers + + A Reader is a Step that reads data from a source based on the input parameters + and stores the result in self.output.df (DataFrame). + + When implementing a Reader, the execute() method should be implemented. + The execute() method should read from the source and store the result in self.output.df. + + The Reader class implements a standard read() method that calls the execute() method and returns the result. This + method can be used to read data from a Reader without having to call the execute() method directly. Read method + does not need to be implemented in the child class. + + The Reader class also implements a shorthand for accessing the output Dataframe through the df-property. If the + output.df is None, .execute() will be run first. + """ + + @abstractmethod + def execute(self) -> PandasStep.Output: + """Execute on a Reader should handle self.output.df (output) as a minimum + Read from whichever source -> store result in self.output.df + """ + # self.output.df # output dataframe + ... diff --git a/src/koheesio/pandas/readers/excel.py b/src/koheesio/pandas/readers/excel.py new file mode 100644 index 00000000..5432aedf --- /dev/null +++ b/src/koheesio/pandas/readers/excel.py @@ -0,0 +1,50 @@ +""" +Excel reader for Spark + +Note +---- +Ensure the 'excel' extra is installed before using this reader. +Default implementation uses openpyxl as the engine for reading Excel files. +Other implementations can be used by passing the correct keyword arguments to the reader. + +See Also +-------- +- https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html +- koheesio.pandas.readers.excel.ExcelReader +""" + +from typing import List, Optional, Union +from pathlib import Path + +import pandas as pd + +from koheesio.models import ExtraParamsMixin, Field +from koheesio.pandas.readers import Reader + + +class ExcelReader(Reader, ExtraParamsMixin): + """Read data from an Excel file + + See Also + -------- + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html + + Attributes + ---------- + path : Union[str, Path] + The path to the Excel file + sheet_name : str + The name of the sheet to read + header : Optional[Union[int, List[int]]] + Row(s) to use as the column names + + Any other keyword arguments will be passed to pd.read_excel. + """ + + path: Union[str, Path] = Field(description="The path to the Excel file") + sheet_name: str = Field(default="Sheet1", description="The name of the sheet to read") + header: Optional[Union[int, List[int]]] = Field(default=0, description="Row(s) to use as the column names") + + def execute(self): + extra_params = self.params or {} + self.output.df = pd.read_excel(self.path, sheet_name=self.sheet_name, header=self.header, **extra_params) diff --git a/src/koheesio/spark/readers/__init__.py b/src/koheesio/spark/readers/__init__.py index 42d0870d..ab81a7d4 100644 --- a/src/koheesio/spark/readers/__init__.py +++ b/src/koheesio/spark/readers/__init__.py @@ -6,15 +6,13 @@ [reference/concepts/steps/readers](../../../reference/concepts/readers.md) section of the Koheesio documentation. """ -from typing import Optional from abc import ABC, abstractmethod -from pyspark.sql import DataFrame - +from koheesio.models.reader import BaseReader from koheesio.spark import SparkStep -class Reader(SparkStep, ABC): +class Reader(BaseReader, SparkStep, ABC): """Base class for all Readers A Reader is a Step that reads data from a source based on the input parameters @@ -33,24 +31,8 @@ class Reader(SparkStep, ABC): output.df is None, .execute() will be run first. """ - @property - def df(self) -> Optional[DataFrame]: - """Shorthand for accessing self.output.df - If the output.df is None, .execute() will be run first - """ - if not self.output.get("df"): - self.execute() - return self.output.df - @abstractmethod - def execute(self): + def execute(self) -> SparkStep.Output: """Execute on a Reader should handle self.output.df (output) as a minimum Read from whichever source -> store result in self.output.df """ - # self.output.df # output dataframe - ... - - def read(self) -> Optional[DataFrame]: - """Read from a Reader without having to call the execute() method directly""" - self.execute() - return self.output.df diff --git a/src/koheesio/spark/readers/excel.py b/src/koheesio/spark/readers/excel.py new file mode 100644 index 00000000..4b52cc79 --- /dev/null +++ b/src/koheesio/spark/readers/excel.py @@ -0,0 +1,40 @@ +""" +Excel reader for Spark + +Note +---- +Ensure the 'excel' extra is installed before using this reader. +Default implementation uses openpyxl as the engine for reading Excel files. +Other implementations can be used by passing the correct keyword arguments to the reader. + +See Also +-------- +- https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html +- koheesio.pandas.readers.excel.ExcelReader +""" + +from pyspark.pandas import DataFrame as PandasDataFrame + +from koheesio.pandas.readers.excel import ExcelReader as PandasExcelReader +from koheesio.spark.readers import Reader + + +class ExcelReader(Reader, PandasExcelReader): + """Read data from an Excel file + + This class is a wrapper around the PandasExcelReader class. It reads an Excel file first using pandas, and then + converts the pandas DataFrame to a Spark DataFrame. + + Attributes + ---------- + path: str + The path to the Excel file + sheet_name: str + The name of the sheet to read + header: int + The row to use as the column names + """ + + def execute(self): + pdf: PandasDataFrame = PandasExcelReader.from_step(self).execute().df + self.output.df = self.spark.createDataFrame(pdf) diff --git a/src/koheesio/utils.py b/src/koheesio/utils.py index 61bd16af..9547f8c7 100644 --- a/src/koheesio/utils.py +++ b/src/koheesio/utils.py @@ -3,7 +3,6 @@ """ import inspect -import os import uuid from typing import Any, Callable, Dict, Optional, Tuple from functools import partial diff --git a/tests/_data/readers/excel_file/dummy.xlsx b/tests/_data/readers/excel_file/dummy.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..eeef11058c51f46cd0887f3d58bb558572fbaccc GIT binary patch literal 10314 zcmeHtg;!Nu_x7O#1f{#<&>aHODTsuGNQrdkp_@auq=a1J~NiGJOUyv-~j*y002+{7MR0dTfqSUm`DHs4gdvS zThi9X(b&dOPtDEF7^KVQYHdZCjR?<>34n*a|G(qEcm*Kw3QC%hVyUzropiHo5Q)Q$pTFT(*fy&D&S*3r0<34lT(-D?KTu z8nHjZGasn;b^rW=P~^8>s?${-NH{jGAWf}HNznUCdnaTjI)_9Vd&o)`mh16*LzO&P(AYB~FfTVqcX&$ere|c;b|jVB zigV-1s`#N-g_;NZoNMG+4>0q@N2{9W`#_po7awY8aqnRtIGmzz5yn^l1j(U4XM7q0 z0055)0ARpY#?^}5#n!>nz}D9CCyPPUb#06JF}?K6?+{kg=JGv@x7ATHBPZ0>Aos)6o);M*y?N7)_>085AeCK8rcskVM|M zny8wpNb(l}jjy=09D#-(OW#FS1Y7izfTA(`Kpb5bpFoKro~l(cd*Zakh47cUee*VA z#YdoJ3#RGRseb7N_%52VdiQN{g`DA>ZB-gtF9_u>nSEqhKnC(1V&_{7dcW=U!jw%F zM}s^z0P&t&Mwb!?^}6Ik=K=wKj6U&LVUTVo4fD2FY=TSIkHt2WTXP=x@ozQQ;y*+} zH^>`iH%=zQGATp+Ciuh))&O|1S=ArjC}P{bLeCyPQ|w)QsBz66X@R! z?QRkqO9{Af|$qU)cCd__)#Ji zZH-oDhDmUlXHm_tnNF${;6PP*g{`yhGA&bfjHr62M6z#=qLcmGjCjn>6y*@o3;G^S z958zYS%6K#B&(FNblQLLc6)CKI-6k$IiD{oW2R+C_C(ipWVY7hbL<=9Cgeq213@lb zY%w18&T%?&mOAKnKZ1{lk>Zp+&?Q12CO^*aWKdM$oaD$s4Rp9zS|6crQ_hi}PvH+uYxGXZttdvz z6QvltvkyI0z<#s@9IDJ!VWl0KKIBcUQhy=%ZbF1|`6}IN_W@@MVwj=~*8mk&?dlAO zg|ClIH-k2AVra4V+o48Aa|irH>rmmGn8(iW{&Wm^Rp-7vsAh>DcUSUE`bcW?=J2#t zz&mh@**PEN?32H??p&Lcw(SwemWVOiFubPuA!6i=Vq|ko6o31m6KjhC=b;NJu1+-c z#TaGWPV2uS*-s??{SgeojWD_+2B5&fko*VZ{T0vuVLv!nXbRKv-+e$-6y&=&FrjF- z!5l8B&Ujd}j_lOC>N^-HebtQfbX45l$4ewkjk>R=Wt!iS1=c$K~HRX8;lf=j2mM z9XE^z@mLIVIB{8kM#=hN-hZHa-S;v;b+-guT~|k$!f5jKu0eAC^g#yLjbiMG2vtt# zksDxaYl4dTyj}VfS{fR37JW%Nt3LZb6N7^G0&4#S3@l) z5?T^9yc;A}(uzTNO|K9)W`?N(AIXttRZDbpVTkjB!9JZ2sFCj%A>yACeKe9`^32es z^uZw_YUF^~o>wZ@hSZ~!CO`A6cJ_pd6c5iD?F}uo=H0kz1AD@#b~#`8q$2MrkRkgyRw6cx=}30&7DScHC31Fl@Ho zzE!+6;pX1s?<7FhK;B^()f+jNUew+&rgE{B=FZ<9bV#9A*}YqUYsp@XPQZ;Ug`DF( zxu&%840M_|{tRW|zI2Iy_T3?wwq25ZqrH=>&;OG{*EvU;jNb#A zN<#vfa>5ygg5!rwp+4Sn{Ed8|Rk!bCxF{HfSTDzCt2j)iSv@^(MroOo*?*G>kBxMY>r%VD^E3RLqfyXFI3!#ikuL_}4b z3=5iHG(Vspt$sa7xo;HkvbT$Vtn%t&6m%W`^2@~<8n&dtLGu_vY%nWbH$W z04!~dsV+^ZSMcPa9ot;9>Rr`^y;@-ESye^Vq(!4>c#$a>Z$-ZcprP9t_P(%Wr_^@M z|7h(?l>>e9E4HZgxA^IZ^?Hs)=h=8JtwDnT!PNGW{(0YdD&EIGy55tVbEKa0 zuB=UAir9EaL>xrRM=%$W%$D{HBuze`8>HKa;jp(v#o-edi7sl9sJx+Sy#0Q0+3ETF zyLF<(A9iqHN|#}80H&Cz9)~k65F$U?%eOcx z=gZ3Q$Erqtc*JYLK!p}2>Yv26JX=mX`I+#YP zU;Pt=r*Esp>tO)IqW)P+_ys~oGh=IG_Fvau_}kSQiYBSUYenw86C2c>8r<^G7Z@$z zJKc9*Nzh%Fm5UF0Gi)cO(!`7l_i%GU-VfeKTd`OJ7k2K61L4~ZXipDcH0Ug@I4KK_n)a*QG}`_-Idxyehu~$-F)G{QHB6=)5nsvDV-dZz?XV1y>I^i~ z&^%Gy4PQO-YSoP7R1k_^CY%t_V%dtw41Kk|Fwa22S+5X3fxt~STy0?>N4R?}G>X!Y zh(I@rWTK0)k$g0O+Mq&i@S5p7Yv*EPqU`E*)MkiaMTG_@Ovijy@ zy_6)AYS1+KbNdb;`jtMW*w*BY-IXm%6b1^LoaLJ~seN>F5!qox1@&nUfEE;c{VU&6 zv+M0i_~xD4_fG`u>33*kW5&5$>WN>d6cddbC~j|i?ru(8wp#CQ=P~qd`y$^PUl;q_ z-cHsn-Nd`Uwg_Z=b9+&&q33h*gWc!+mO}4o^%-J%;tCEiPN-rp(0$O*Rm|{Vs%Uap zs@x0e7p$nHem241V!~LK`-r&-N|mf|MDX4j>$F`Jo8nJb$wF|Q2Cq57SeH3BAtQmG z0T?GbKs2RSedEFnuM_N}Ia}<8gxf;DE{`z<9zWM*LdL2vOXvE)&`gW^QueNU1l3+> z!8BgWSva4&9jM1}OeF2~B8@J2v8SNew)tBLU!S*F9|SdJ(e5OtpXMvU(!tH2FaP}hdz z6{mL36BP#c{c*Pd0d|_;YfsD-ekX@~Ikj$YrV%PB&Zl4dkwiC6AC8?ZBY)U+I#Yel z;?gRDCFF;wk*|gra%&7EZE04jhmucQ`GqnM97|alPp4-NZ*7$&zA~f0>No_Zrk}B{oWOeJ0Ye1Doj$*B zCoBVp1+n$o%_0Ni)k!^Gt30%dy^QotvBT+#cVNvBW?8H%{CuCyY9p()jXcT0T}$g>Sh4FR<=L{Dzi92DxzMM_o; zMMmb*@>|zU>Ed7Gt$8C5E6ws(e^d*ysr?~7WGlH)uAA~|C&)JuiW)j@^dvOJU&@mE zoEW!w>>Fdtx{UDoEo2d+CQoh~$u$fQB$z0|p41f6s@%?TV#e!hz6 z>&F6=e9~2N;hBa;7i6%Z)$kowOLy=C(U?;Mz9%pq8X?+tx;Z9Z?Tg~s@trso*cC}D zr5aQbKO`EG*|~w0 zcy~{VfOsm8Dk8YTQz3B|^!Z^ZU1ncSj*o0R{eX0#M(a<@9es+{C7{dR=Id=cweR~3 z#7#c#o{fjpIy$2qVZ=``6Iw`8ADXUjS1o6I>&xx`<1}JX{okq!ftB#hbXp zowdo3kXB1U)%2SZ$x>IeLOs1I`4Uj9@~Sdh^4_I zxMza;C3&1n7_YldRYY&|{R@*&x}yfR=T&OANi3LYdQQ+$b%BlG9yi`LY_Zf7f>$%I zMm6_l!R}Q9(p5-Jc*FfOmM0%sn-gY8t84-lEXSm^b*&XAoJwp=rXCNZ&*kbHGS<)4 zP*?XM1&s>`4h}i6uO~AnXw2CKz5+W#mo6jUcws(yfY-V}sQ#^{QO}&uz2gyi+oFBh%2|Ylx;C$ z%%n&sGvb6d ztxpPh=5xIHAIqF%VJv%IHGt~~(?B5Yw!9UmgJ=Ir^oEVWJ}a=~Z0i94@aPYs2RXV~ z8H0X?*tzOUwy7ML-Wk<*?LP#0N8bUllzA0mn4BhaoyXshafGVhiFM2Af4)6S^P`nY zt1b0FJ-ygJayd(5|E^AhG{>z9D#0ezIVa>y*Z~(SfLle37twR(RMeHT@Z{|VEIbpu zuQ8`<@pkP8lB5Ay-w(vAwb;br(3Dq6>9R{nNg%%L8u)YME46wx_jhaX<9*4P z(fg>L%p>Kg*hOiBxlm?o^yogkmGS_oAmiw1Thd#?k{0+g5ZG&nW5SDg?Xd!@HeaMA7=UaKZwI5<6rBiN#B5a2(+q&86kR ziMvdKOm-YD4fSxK9JG4LSa7xVG?FgBo+A@Nf|927EQnH}{dptK>9%Xq^5;u#=5xiM zXJzc3W!A_su=~XtMGhx|%I?MK(St*2!A_j(HRlWSmc}6o+YNTp#r0U}MIo)yW;2H> zos!B1ic3DxXNKFNK`MOGS@Ph%Yif-sg&eq->sW>+--hTz1uo`E=(CkwG~KHTOO>Yj%atVr))VUKRM}8#qKs=R@jG7P zQ!9)ux-9rB5YDu{!9M7RIUYNQUpcbKO(2&}o4N#%X*^e<0f#N0{Q?P#DFd-~UW#u8mri z)RIv3pO2P9-H$v`15N7(pkmS z!a1J|xj9;OLGyVeY!SQkqwg&oOrt+`@mH+S-jx?NxsIe~uY8F!oX%f%NQ$(GzIa?MR?Yw$^ z(Gsp2VD?2`*cU;6j?V{SJMX4Jv$Olya{HuYzj|p;lk!{Y_T<9b-I+;JE9>f-bE1}- zvSp|Z#NK3^9iHXMK%UJ+F*y#lwo_TpUeUW@M`^}YO7sYooa?Si(dMP%aG1x2yO%6dEs7R+R{9Q_hs|5lAIBxZr zMjtR`Ex>R67=n_Fd1!!=MY7SL=ZBSp4?0@iC}!~XJm(L;JMLf+iF5qEx}Lt?a@wS*`$!(N zKk9W@WpR>^k=^jNT}yitW3ngxv_k2Jx)~3o&b_1~8aU%QC`3z^W2Q5eRr1O9U(J#T zo@!hkhDS@7X=1_Zr$)Ah$_}=6Aa+Ar2jf3G`u{7N!j>)|_N9_P2X5e!!Zk+jh=leG zzO^r~akmdhXTa0KySES|L1mrr{eZQc>M0JzJZOI!WJAo3VZ7o6e&UvgErkFUQqGQT zm;-JxYF|zFawjtPg+CVuYjKz*zj@P@ZIV5Hm|09*J=E?Rm5Yw!1U{XbLN}y2Rd&GP z4|Z6s9S7F!7rNAMtfn0Rf&*CWP!r{C1onNsg|Wx57(azTL>5kdQX8b1jRzvk{YoeTd)OvPmd; zNgQjp5!XO57BNhi$I|Gww|K|7Bu&QFZy4qiik2Zo^oU5aTp$@!$(N_wbcyB$W@GgdU``*|MxFUvWQOG{tr~D$SIQ@3II?8#uGkAfMv;38w>W^XKRmPLFDggtl z!g}s0q_7$MrN{Wqv7lswVb+<=-K|1-;AYXB)|&p1(pI-CB2W+Q8~$K84i6G6zbAJbV1H2LhYJ7(qqS~|%gJ^Q(~?7rh^%{}vS|`9t)6zSGI53-~0{*=Fe;Ml_SpWb4 literal 0 HcmV?d00001 diff --git a/tests/conftest.py b/tests/conftest.py index 6a6e4b94..a0090a0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ import os import time import uuid +from pathlib import Path import pytest from koheesio.logger import LoggingFactory +from koheesio.utils import get_project_root if os.name != "nt": # 'nt' is the name for Windows # force time zone to be UTC @@ -12,6 +14,12 @@ time.tzset() +PROJECT_ROOT = get_project_root() + +TEST_DATA_PATH = Path(PROJECT_ROOT / "tests" / "_data") +DELTA_FILE = Path(TEST_DATA_PATH / "readers" / "delta_file") + + @pytest.fixture(scope="session") def random_uuid(): return str(uuid.uuid4()).replace("-", "_") @@ -20,3 +28,13 @@ def random_uuid(): @pytest.fixture(scope="session") def logger(random_uuid): return LoggingFactory.get_logger(name="conf_test" + random_uuid) + + +@pytest.fixture(scope="session") +def data_path(): + return TEST_DATA_PATH.as_posix() + + +@pytest.fixture(scope="session") +def delta_file(): + return DELTA_FILE.as_posix() diff --git a/tests/pandas/readers/test_pandas_excel.py b/tests/pandas/readers/test_pandas_excel.py new file mode 100644 index 00000000..f9fc0534 --- /dev/null +++ b/tests/pandas/readers/test_pandas_excel.py @@ -0,0 +1,28 @@ +from pathlib import Path + +import pandas as pd + +from koheesio.pandas.readers.excel import ExcelReader + + +def test_excel_reader(data_path): + # Define the path to the test Excel file and the expected DataFrame + test_file = Path(data_path) / "readers" / "excel_file" / "dummy.xlsx" + expected_df = pd.DataFrame( + { + "a": ["foo", "so long"], + "b": ["bar", "and thanks"], + "c": ["baz", "for all the fish"], + "d": [None, 42], + "e": pd.to_datetime(["1/1/24", "1/1/24"]), + } + ) + + # Initialize the ExcelReader with the path to the test file and the correct sheet name + reader = ExcelReader(path=test_file, sheet_name="sheet_to_select", header=0) + + # Execute the reader + reader.execute() + + # Assert that the output DataFrame is as expected + pd.testing.assert_frame_equal(reader.output.df, expected_df, check_dtype=False) diff --git a/tests/spark/conftest.py b/tests/spark/conftest.py index fc01d6f4..f6f40f93 100644 --- a/tests/spark/conftest.py +++ b/tests/spark/conftest.py @@ -32,11 +32,6 @@ from koheesio.logger import LoggingFactory from koheesio.spark.readers.dummy import DummyReader -from koheesio.utils import get_project_root - -PROJECT_ROOT = get_project_root() -TEST_DATA_PATH = Path(PROJECT_ROOT / "tests" / "_data") -DELTA_FILE = Path(TEST_DATA_PATH / "readers" / "delta_file") @pytest.fixture(scope="session") @@ -53,11 +48,6 @@ def checkpoint_folder(tmp_path_factory, random_uuid, logger): yield fldr.as_posix() -@pytest.fixture(scope="session") -def data_path(): - return TEST_DATA_PATH.as_posix() - - @pytest.fixture(scope="session") def spark(warehouse_path, random_uuid): """Spark session fixture with Delta enabled.""" @@ -107,14 +97,14 @@ def set_env_vars(): @pytest.fixture(scope="session", autouse=True) -def setup(spark): +def setup(spark, delta_file): db_name = "klettern" if not spark.catalog.databaseExists(db_name): spark.sql(f"CREATE DATABASE {db_name}") spark.sql(f"USE {db_name}") - setup_test_data(spark=spark) + setup_test_data(spark=spark, delta_file=Path(delta_file)) yield @@ -142,8 +132,8 @@ def sample_df_to_partition(spark): @pytest.fixture -def streaming_dummy_df(spark): - setup_test_data(spark=spark) +def streaming_dummy_df(spark, delta_file): + setup_test_data(spark=spark, delta_file=Path(delta_file)) yield spark.readStream.table("delta_test_table") @@ -198,12 +188,12 @@ def sample_df_with_string_timestamp(spark): return spark.createDataFrame(data, schema) -def setup_test_data(spark): +def setup_test_data(spark, delta_file): """ Sets up test data for the Spark session. Reads a Delta file, creates a temporary view, and populates a Delta table with the view's data. """ - delta_file = DELTA_FILE.absolute().as_posix() + delta_file = delta_file.absolute().as_posix() spark.read.format("delta").load(delta_file).limit(10).createOrReplaceTempView("delta_test_view") spark.sql( dedent( diff --git a/tests/spark/integrations/dq/test_spark_expectations.py b/tests/spark/integrations/dq/test_spark_expectations.py index a8ef6fbf..259af003 100644 --- a/tests/spark/integrations/dq/test_spark_expectations.py +++ b/tests/spark/integrations/dq/test_spark_expectations.py @@ -1,10 +1,12 @@ from typing import List, Union -import pyspark import pytest -from koheesio.utils import get_project_root + +import pyspark from pyspark.sql import SparkSession +from koheesio.utils import get_project_root + PROJECT_ROOT = get_project_root() pytestmark = pytest.mark.spark diff --git a/tests/spark/readers/test_spark_excel.py b/tests/spark/readers/test_spark_excel.py new file mode 100644 index 00000000..1a6d34a8 --- /dev/null +++ b/tests/spark/readers/test_spark_excel.py @@ -0,0 +1,27 @@ +import datetime +from pathlib import Path + +from koheesio.spark.readers.excel import ExcelReader + + +def test_excel_reader(spark, data_path): + # Define the path to the test Excel file and the expected DataFrame + test_file = Path(data_path) / "readers" / "excel_file" / "dummy.xlsx" + + # Initialize the ExcelReader with the path to the test file and the correct sheet name + reader = ExcelReader(path=test_file, sheet_name="sheet_to_select", header=0) + + # Execute the reader + reader.execute() + + # Define the expected DataFrame + expected_df = spark.createDataFrame( + [ + ("foo", "bar", "baz", None, datetime.datetime(2024, 1, 1, 0, 0)), + ("so long", "and thanks", "for all the fish", 42, datetime.datetime(2024, 1, 1, 0, 0)), + ], + ["a", "b", "c", "d", "e"], + ) + + # Assert that the output DataFrame is as expected + assert sorted(reader.output.df.collect()) == sorted(expected_df.collect()) diff --git a/tests/spark/test_delta.py b/tests/spark/test_delta.py index 7b0b1008..5bad4a7d 100644 --- a/tests/spark/test_delta.py +++ b/tests/spark/test_delta.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from unittest.mock import patch import pytest @@ -74,8 +75,8 @@ def test_table(value, expected): log.info("delta test completed") -def test_delta_table_properties(spark, setup): - setup_test_data(spark=spark) +def test_delta_table_properties(spark, setup, delta_file): + setup_test_data(spark=spark, delta_file=Path(delta_file)) table_name = "delta_test_table" dt = DeltaTableStep( table=table_name, diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index d9822c57..8c392fdd 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,7 +1,7 @@ from textwrap import dedent import pytest -from conftest import TEST_DATA_PATH +from tests.conftest import TEST_DATA_PATH from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform diff --git a/tests/spark/writers/delta/test_delta_writer.py b/tests/spark/writers/delta/test_delta_writer.py index 21d02725..4a360692 100644 --- a/tests/spark/writers/delta/test_delta_writer.py +++ b/tests/spark/writers/delta/test_delta_writer.py @@ -4,14 +4,17 @@ import pytest from conftest import await_job_completion from delta import DeltaTable + +from pydantic import ValidationError + +from pyspark.sql import functions as F + from koheesio.spark import AnalysisException from koheesio.spark.delta import DeltaTableStep from koheesio.spark.writers import BatchOutputMode, StreamingOutputMode from koheesio.spark.writers.delta import DeltaTableStreamWriter, DeltaTableWriter from koheesio.spark.writers.delta.utils import log_clauses from koheesio.spark.writers.stream import Trigger -from pydantic import ValidationError -from pyspark.sql import functions as F pytestmark = pytest.mark.spark From ef7c9343d6a307cd43c62dbd10998d00fc5f4936 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:09:04 +0200 Subject: [PATCH 02/17] Ability to read Excel files - introduces pandas module - abstracted Reader class to a common BaseReader (as part of models module) - re-implemented spark.readers.Reader based on BaseReader - introduces ExcelReader at koheesio.pandas.readers.excel.ExcelReader - introduces ExcelReader at koheesio.spark.reader.excel.ExcelReader - added unittests to cover the above - added excel extra dependency - added docs --- pyproject.toml | 7 ++- src/koheesio/pandas/__init__.py | 7 +-- src/koheesio/spark/utils.py | 27 ++++++++++ tests/pandas/readers/test_pandas_excel.py | 2 +- tests/spark/test_spark_utils.py | 53 +++++++++++++++++++ .../transformations/test_sql_transform.py | 3 +- tests/utils/test_utils.py | 20 ------- 7 files changed, 92 insertions(+), 27 deletions(-) create mode 100644 tests/spark/test_spark_utils.py diff --git a/pyproject.toml b/pyproject.toml index 3c14c130..f49d0e85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -287,6 +287,9 @@ matrix.version.extra-dependencies = [ { value = "spark-expectations>=2.1.0", if = [ "pyspark33", ] }, + { value = "pandas<2", if = [ + "pyspark33", + ] }, { value = "pyspark>=3.4,<3.5", if = [ "pyspark34", ] }, @@ -392,7 +395,7 @@ Available scripts: - `coverage` or `cov` - run the test suite with coverage. """ path = ".venv" -python = "3.10" +python = "3.9" template = "default" features = [ "async", @@ -408,7 +411,7 @@ features = [ "test", "docs", ] -extra-dependencies = ["pyspark==3.4.*"] +extra-dependencies = ["pyspark==3.3.*", "pandas<2"] ### ~~~~~~~~~~~~~~~~~~ ### diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py index b8494479..a9d324aa 100644 --- a/src/koheesio/pandas/__init__.py +++ b/src/koheesio/pandas/__init__.py @@ -7,10 +7,11 @@ from typing import Optional from abc import ABC -from pandas import DataFrame - from koheesio import Step, StepOutput from koheesio.models import Field +from koheesio.spark.utils import import_pandas_based_on_pyspark_version + +pandas = import_pandas_based_on_pyspark_version() class PandasStep(Step, ABC): @@ -23,4 +24,4 @@ class PandasStep(Step, ABC): class Output(StepOutput): """Output class for PandasStep""" - df: Optional[DataFrame] = Field(default=None, description="The Pandas DataFrame") + df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index cffdf75d..c0cf4efb 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -29,6 +29,7 @@ __all__ = [ "SparkDatatype", "get_spark_minor_version", + "import_pandas_based_on_pyspark_version", "on_databricks", "schema_struct_to_schema_str", "spark_data_type_is_array", @@ -177,3 +178,29 @@ def schema_struct_to_schema_str(schema: StructType) -> str: if not schema: return "" return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) + + +def import_pandas_based_on_pyspark_version(): + """ + This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. + If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version + of pandas should be installed. + """ + pyspark_version = get_spark_minor_version() + + if pyspark_version < 3.4: + try: + import pandas as pd + + assert pd.__version__ < "2" + except (ImportError, AssertionError): + raise ImportError("For PySpark <3.4, please install Pandas version < 2") + else: + try: + import pandas as pd + + assert pd.__version__ >= "2" + except (ImportError, AssertionError): + raise ImportError("For PySpark versions other than 3.3, please install Pandas version >= 2") + + return pd diff --git a/tests/pandas/readers/test_pandas_excel.py b/tests/pandas/readers/test_pandas_excel.py index f9fc0534..4e00a235 100644 --- a/tests/pandas/readers/test_pandas_excel.py +++ b/tests/pandas/readers/test_pandas_excel.py @@ -14,7 +14,7 @@ def test_excel_reader(data_path): "b": ["bar", "and thanks"], "c": ["baz", "for all the fish"], "d": [None, 42], - "e": pd.to_datetime(["1/1/24", "1/1/24"]), + "e": pd.to_datetime(["1/1/24", "1/1/24"], format="%m/%d/%y"), } ) diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py new file mode 100644 index 00000000..6455bea6 --- /dev/null +++ b/tests/spark/test_spark_utils.py @@ -0,0 +1,53 @@ +from os import environ +from unittest.mock import patch + +import pytest + +from pyspark.sql.types import StringType, StructField, StructType + +from koheesio.spark.utils import ( + import_pandas_based_on_pyspark_version, + on_databricks, + schema_struct_to_schema_str, +) + + +def test_schema_struct_to_schema_str(): + struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) + val = schema_struct_to_schema_str(struct_schema) + assert val == "a STRING,\nb STRING" + assert schema_struct_to_schema_str(None) == "" + + +@pytest.mark.parametrize( + "env_var_value, expected_result", + [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], +) +def test_on_databricks(env_var_value, expected_result): + if env_var_value is not None: + with patch.dict(environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): + assert on_databricks() == expected_result + else: + with patch.dict(environ, clear=True): + assert on_databricks() == expected_result + + +@pytest.mark.parametrize( + "spark_version, pandas_version, expected_error", + [ + (3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error + (3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error + (3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error + (3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error + ], +) +def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): + with ( + patch("koheesio.spark.utils.get_spark_minor_version", return_value=spark_version), + patch("pandas.__version__", new=pandas_version), + ): + if expected_error: + with pytest.raises(expected_error): + import_pandas_based_on_pyspark_version() + else: + import_pandas_based_on_pyspark_version() # This should not raise an error diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index 8c392fdd..0ffe73f4 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,11 +1,12 @@ from textwrap import dedent import pytest -from tests.conftest import TEST_DATA_PATH from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform +from tests.conftest import TEST_DATA_PATH + pytestmark = pytest.mark.spark log = LoggingFactory.get_logger(name="test_sql_transform") diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 90398914..ead642f4 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -25,26 +25,6 @@ def test_import_class(): assert import_class("datetime.datetime") == datetime.datetime -@pytest.mark.parametrize( - "env_var_value, expected_result", - [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], -) -def test_on_databricks(env_var_value, expected_result): - if env_var_value is not None: - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): - assert on_databricks() == expected_result - else: - with patch.dict(os.environ, clear=True): - assert on_databricks() == expected_result - - -def test_schema_struct_to_schema_str(): - struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) - val = schema_struct_to_schema_str(struct_schema) - assert val == "a STRING,\nb STRING" - assert schema_struct_to_schema_str(None) == "" - - def test_get_random_string(): assert get_random_string(10) != get_random_string(10) assert len(get_random_string(10)) == 10 From 0cd0567f39db6f48376f71164c85753935c7dcfd Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:09:38 +0200 Subject: [PATCH 03/17] Revert "Ability to read Excel files" This reverts commit ef7c9343d6a307cd43c62dbd10998d00fc5f4936. --- pyproject.toml | 7 +-- src/koheesio/pandas/__init__.py | 7 ++- src/koheesio/spark/utils.py | 27 ---------- tests/pandas/readers/test_pandas_excel.py | 2 +- tests/spark/test_spark_utils.py | 53 ------------------- .../transformations/test_sql_transform.py | 3 +- tests/utils/test_utils.py | 20 +++++++ 7 files changed, 27 insertions(+), 92 deletions(-) delete mode 100644 tests/spark/test_spark_utils.py diff --git a/pyproject.toml b/pyproject.toml index f49d0e85..3c14c130 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -287,9 +287,6 @@ matrix.version.extra-dependencies = [ { value = "spark-expectations>=2.1.0", if = [ "pyspark33", ] }, - { value = "pandas<2", if = [ - "pyspark33", - ] }, { value = "pyspark>=3.4,<3.5", if = [ "pyspark34", ] }, @@ -395,7 +392,7 @@ Available scripts: - `coverage` or `cov` - run the test suite with coverage. """ path = ".venv" -python = "3.9" +python = "3.10" template = "default" features = [ "async", @@ -411,7 +408,7 @@ features = [ "test", "docs", ] -extra-dependencies = ["pyspark==3.3.*", "pandas<2"] +extra-dependencies = ["pyspark==3.4.*"] ### ~~~~~~~~~~~~~~~~~~ ### diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py index a9d324aa..b8494479 100644 --- a/src/koheesio/pandas/__init__.py +++ b/src/koheesio/pandas/__init__.py @@ -7,11 +7,10 @@ from typing import Optional from abc import ABC +from pandas import DataFrame + from koheesio import Step, StepOutput from koheesio.models import Field -from koheesio.spark.utils import import_pandas_based_on_pyspark_version - -pandas = import_pandas_based_on_pyspark_version() class PandasStep(Step, ABC): @@ -24,4 +23,4 @@ class PandasStep(Step, ABC): class Output(StepOutput): """Output class for PandasStep""" - df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") + df: Optional[DataFrame] = Field(default=None, description="The Pandas DataFrame") diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index c0cf4efb..cffdf75d 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -29,7 +29,6 @@ __all__ = [ "SparkDatatype", "get_spark_minor_version", - "import_pandas_based_on_pyspark_version", "on_databricks", "schema_struct_to_schema_str", "spark_data_type_is_array", @@ -178,29 +177,3 @@ def schema_struct_to_schema_str(schema: StructType) -> str: if not schema: return "" return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) - - -def import_pandas_based_on_pyspark_version(): - """ - This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. - If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version - of pandas should be installed. - """ - pyspark_version = get_spark_minor_version() - - if pyspark_version < 3.4: - try: - import pandas as pd - - assert pd.__version__ < "2" - except (ImportError, AssertionError): - raise ImportError("For PySpark <3.4, please install Pandas version < 2") - else: - try: - import pandas as pd - - assert pd.__version__ >= "2" - except (ImportError, AssertionError): - raise ImportError("For PySpark versions other than 3.3, please install Pandas version >= 2") - - return pd diff --git a/tests/pandas/readers/test_pandas_excel.py b/tests/pandas/readers/test_pandas_excel.py index 4e00a235..f9fc0534 100644 --- a/tests/pandas/readers/test_pandas_excel.py +++ b/tests/pandas/readers/test_pandas_excel.py @@ -14,7 +14,7 @@ def test_excel_reader(data_path): "b": ["bar", "and thanks"], "c": ["baz", "for all the fish"], "d": [None, 42], - "e": pd.to_datetime(["1/1/24", "1/1/24"], format="%m/%d/%y"), + "e": pd.to_datetime(["1/1/24", "1/1/24"]), } ) diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py deleted file mode 100644 index 6455bea6..00000000 --- a/tests/spark/test_spark_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -from os import environ -from unittest.mock import patch - -import pytest - -from pyspark.sql.types import StringType, StructField, StructType - -from koheesio.spark.utils import ( - import_pandas_based_on_pyspark_version, - on_databricks, - schema_struct_to_schema_str, -) - - -def test_schema_struct_to_schema_str(): - struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) - val = schema_struct_to_schema_str(struct_schema) - assert val == "a STRING,\nb STRING" - assert schema_struct_to_schema_str(None) == "" - - -@pytest.mark.parametrize( - "env_var_value, expected_result", - [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], -) -def test_on_databricks(env_var_value, expected_result): - if env_var_value is not None: - with patch.dict(environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): - assert on_databricks() == expected_result - else: - with patch.dict(environ, clear=True): - assert on_databricks() == expected_result - - -@pytest.mark.parametrize( - "spark_version, pandas_version, expected_error", - [ - (3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error - (3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error - (3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error - (3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error - ], -) -def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): - with ( - patch("koheesio.spark.utils.get_spark_minor_version", return_value=spark_version), - patch("pandas.__version__", new=pandas_version), - ): - if expected_error: - with pytest.raises(expected_error): - import_pandas_based_on_pyspark_version() - else: - import_pandas_based_on_pyspark_version() # This should not raise an error diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index 0ffe73f4..8c392fdd 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,12 +1,11 @@ from textwrap import dedent import pytest +from tests.conftest import TEST_DATA_PATH from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform -from tests.conftest import TEST_DATA_PATH - pytestmark = pytest.mark.spark log = LoggingFactory.get_logger(name="test_sql_transform") diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index ead642f4..90398914 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -25,6 +25,26 @@ def test_import_class(): assert import_class("datetime.datetime") == datetime.datetime +@pytest.mark.parametrize( + "env_var_value, expected_result", + [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], +) +def test_on_databricks(env_var_value, expected_result): + if env_var_value is not None: + with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): + assert on_databricks() == expected_result + else: + with patch.dict(os.environ, clear=True): + assert on_databricks() == expected_result + + +def test_schema_struct_to_schema_str(): + struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) + val = schema_struct_to_schema_str(struct_schema) + assert val == "a STRING,\nb STRING" + assert schema_struct_to_schema_str(None) == "" + + def test_get_random_string(): assert get_random_string(10) != get_random_string(10) assert len(get_random_string(10)) == 10 From 7296b7e6f43f6ecd6bbde5524c98ce68d2cb8337 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 4 Jun 2024 16:11:01 +0200 Subject: [PATCH 04/17] fixes --- pyproject.toml | 7 ++- src/koheesio/pandas/__init__.py | 7 +-- src/koheesio/spark/utils.py | 27 ++++++++++ tests/pandas/readers/test_pandas_excel.py | 2 +- tests/spark/test_spark_utils.py | 53 +++++++++++++++++++ .../transformations/test_sql_transform.py | 3 +- tests/utils/test_utils.py | 20 ------- 7 files changed, 92 insertions(+), 27 deletions(-) create mode 100644 tests/spark/test_spark_utils.py diff --git a/pyproject.toml b/pyproject.toml index 3c14c130..f49d0e85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -287,6 +287,9 @@ matrix.version.extra-dependencies = [ { value = "spark-expectations>=2.1.0", if = [ "pyspark33", ] }, + { value = "pandas<2", if = [ + "pyspark33", + ] }, { value = "pyspark>=3.4,<3.5", if = [ "pyspark34", ] }, @@ -392,7 +395,7 @@ Available scripts: - `coverage` or `cov` - run the test suite with coverage. """ path = ".venv" -python = "3.10" +python = "3.9" template = "default" features = [ "async", @@ -408,7 +411,7 @@ features = [ "test", "docs", ] -extra-dependencies = ["pyspark==3.4.*"] +extra-dependencies = ["pyspark==3.3.*", "pandas<2"] ### ~~~~~~~~~~~~~~~~~~ ### diff --git a/src/koheesio/pandas/__init__.py b/src/koheesio/pandas/__init__.py index b8494479..a9d324aa 100644 --- a/src/koheesio/pandas/__init__.py +++ b/src/koheesio/pandas/__init__.py @@ -7,10 +7,11 @@ from typing import Optional from abc import ABC -from pandas import DataFrame - from koheesio import Step, StepOutput from koheesio.models import Field +from koheesio.spark.utils import import_pandas_based_on_pyspark_version + +pandas = import_pandas_based_on_pyspark_version() class PandasStep(Step, ABC): @@ -23,4 +24,4 @@ class PandasStep(Step, ABC): class Output(StepOutput): """Output class for PandasStep""" - df: Optional[DataFrame] = Field(default=None, description="The Pandas DataFrame") + df: Optional[pandas.DataFrame] = Field(default=None, description="The Pandas DataFrame") diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index cffdf75d..c0cf4efb 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -29,6 +29,7 @@ __all__ = [ "SparkDatatype", "get_spark_minor_version", + "import_pandas_based_on_pyspark_version", "on_databricks", "schema_struct_to_schema_str", "spark_data_type_is_array", @@ -177,3 +178,29 @@ def schema_struct_to_schema_str(schema: StructType) -> str: if not schema: return "" return ",\n".join([f"{field.name} {field.dataType.typeName().upper()}" for field in schema.fields]) + + +def import_pandas_based_on_pyspark_version(): + """ + This function checks the installed version of PySpark and then tries to import the appropriate version of pandas. + If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version + of pandas should be installed. + """ + pyspark_version = get_spark_minor_version() + + if pyspark_version < 3.4: + try: + import pandas as pd + + assert pd.__version__ < "2" + except (ImportError, AssertionError): + raise ImportError("For PySpark <3.4, please install Pandas version < 2") + else: + try: + import pandas as pd + + assert pd.__version__ >= "2" + except (ImportError, AssertionError): + raise ImportError("For PySpark versions other than 3.3, please install Pandas version >= 2") + + return pd diff --git a/tests/pandas/readers/test_pandas_excel.py b/tests/pandas/readers/test_pandas_excel.py index f9fc0534..4e00a235 100644 --- a/tests/pandas/readers/test_pandas_excel.py +++ b/tests/pandas/readers/test_pandas_excel.py @@ -14,7 +14,7 @@ def test_excel_reader(data_path): "b": ["bar", "and thanks"], "c": ["baz", "for all the fish"], "d": [None, 42], - "e": pd.to_datetime(["1/1/24", "1/1/24"]), + "e": pd.to_datetime(["1/1/24", "1/1/24"], format="%m/%d/%y"), } ) diff --git a/tests/spark/test_spark_utils.py b/tests/spark/test_spark_utils.py new file mode 100644 index 00000000..6455bea6 --- /dev/null +++ b/tests/spark/test_spark_utils.py @@ -0,0 +1,53 @@ +from os import environ +from unittest.mock import patch + +import pytest + +from pyspark.sql.types import StringType, StructField, StructType + +from koheesio.spark.utils import ( + import_pandas_based_on_pyspark_version, + on_databricks, + schema_struct_to_schema_str, +) + + +def test_schema_struct_to_schema_str(): + struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) + val = schema_struct_to_schema_str(struct_schema) + assert val == "a STRING,\nb STRING" + assert schema_struct_to_schema_str(None) == "" + + +@pytest.mark.parametrize( + "env_var_value, expected_result", + [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], +) +def test_on_databricks(env_var_value, expected_result): + if env_var_value is not None: + with patch.dict(environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): + assert on_databricks() == expected_result + else: + with patch.dict(environ, clear=True): + assert on_databricks() == expected_result + + +@pytest.mark.parametrize( + "spark_version, pandas_version, expected_error", + [ + (3.3, "1.2.3", None), # PySpark 3.3, pandas < 2, should not raise an error + (3.4, "2.3.4", None), # PySpark not 3.3, pandas >= 2, should not raise an error + (3.3, "2.3.4", ImportError), # PySpark 3.3, pandas >= 2, should raise an error + (3.4, "1.2.3", ImportError), # PySpark not 3.3, pandas < 2, should raise an error + ], +) +def test_import_pandas_based_on_pyspark_version(spark_version, pandas_version, expected_error): + with ( + patch("koheesio.spark.utils.get_spark_minor_version", return_value=spark_version), + patch("pandas.__version__", new=pandas_version), + ): + if expected_error: + with pytest.raises(expected_error): + import_pandas_based_on_pyspark_version() + else: + import_pandas_based_on_pyspark_version() # This should not raise an error diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index 8c392fdd..0ffe73f4 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,11 +1,12 @@ from textwrap import dedent import pytest -from tests.conftest import TEST_DATA_PATH from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform +from tests.conftest import TEST_DATA_PATH + pytestmark = pytest.mark.spark log = LoggingFactory.get_logger(name="test_sql_transform") diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 90398914..ead642f4 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -25,26 +25,6 @@ def test_import_class(): assert import_class("datetime.datetime") == datetime.datetime -@pytest.mark.parametrize( - "env_var_value, expected_result", - [("lts_11_spark_3_scala_2.12", True), ("unit_test", True), (None, False)], -) -def test_on_databricks(env_var_value, expected_result): - if env_var_value is not None: - with patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": env_var_value}): - assert on_databricks() == expected_result - else: - with patch.dict(os.environ, clear=True): - assert on_databricks() == expected_result - - -def test_schema_struct_to_schema_str(): - struct_schema = StructType([StructField("a", StringType()), StructField("b", StringType())]) - val = schema_struct_to_schema_str(struct_schema) - assert val == "a STRING,\nb STRING" - assert schema_struct_to_schema_str(None) == "" - - def test_get_random_string(): assert get_random_string(10) != get_random_string(10) assert len(get_random_string(10)) == 10 From 969c901bc3a9226e0cde0c2c59f1a88947af78d3 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:22:31 +0200 Subject: [PATCH 05/17] fixes --- pyproject.toml | 4 ++-- src/koheesio/spark/utils.py | 25 +++++++++---------------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f49d0e85..7bc894ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -395,7 +395,7 @@ Available scripts: - `coverage` or `cov` - run the test suite with coverage. """ path = ".venv" -python = "3.9" +python = "3.10" template = "default" features = [ "async", @@ -411,7 +411,7 @@ features = [ "test", "docs", ] -extra-dependencies = ["pyspark==3.3.*", "pandas<2"] +extra-dependencies = ["pyspark==3.4.*"] ### ~~~~~~~~~~~~~~~~~~ ### diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index c0cf4efb..7d42b0c0 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -186,21 +186,14 @@ def import_pandas_based_on_pyspark_version(): If the correct version of pandas is not installed, it raises an ImportError with a message indicating which version of pandas should be installed. """ - pyspark_version = get_spark_minor_version() + try: + import pandas as pd + pyspark_version = get_spark_minor_version() + pandas_version = pd.__version__ - if pyspark_version < 3.4: - try: - import pandas as pd + if (pyspark_version < 3.4 and pandas_version >= '2') or (pyspark_version >= 3.4 and pandas_version < '2'): + raise ImportError(f"For PySpark {pyspark_version}, please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}") - assert pd.__version__ < "2" - except (ImportError, AssertionError): - raise ImportError("For PySpark <3.4, please install Pandas version < 2") - else: - try: - import pandas as pd - - assert pd.__version__ >= "2" - except (ImportError, AssertionError): - raise ImportError("For PySpark versions other than 3.3, please install Pandas version >= 2") - - return pd + return pd + except ImportError: + raise ImportError("Pandas module is not installed.") From c06af46205d85e2db0ee12b3e5d0f81e2a150ca7 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Tue, 4 Jun 2024 17:23:09 +0200 Subject: [PATCH 06/17] formatting --- src/koheesio/spark/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/koheesio/spark/utils.py b/src/koheesio/spark/utils.py index 7d42b0c0..b382c4b1 100644 --- a/src/koheesio/spark/utils.py +++ b/src/koheesio/spark/utils.py @@ -188,12 +188,16 @@ def import_pandas_based_on_pyspark_version(): """ try: import pandas as pd + pyspark_version = get_spark_minor_version() pandas_version = pd.__version__ - if (pyspark_version < 3.4 and pandas_version >= '2') or (pyspark_version >= 3.4 and pandas_version < '2'): - raise ImportError(f"For PySpark {pyspark_version}, please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}") + if (pyspark_version < 3.4 and pandas_version >= "2") or (pyspark_version >= 3.4 and pandas_version < "2"): + raise ImportError( + f"For PySpark {pyspark_version}, " + f"please install Pandas version {'< 2' if pyspark_version < 3.4 else '>= 2'}" + ) return pd - except ImportError: - raise ImportError("Pandas module is not installed.") + except ImportError as e: + raise ImportError("Pandas module is not installed.") from e From 4d742c897125b07dcb8d8e72a0e187aa0e8e7fb5 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:43:20 +0200 Subject: [PATCH 07/17] fix for test workflow --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c19a130e..ea76567d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,8 +39,8 @@ jobs: - name: Check changes id: check run: | - echo "python_changed=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" - echo "toml_changed=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" + echo "python_changed=$(git diff --name-only ${{ github.event.pull_request.base.ref }} ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" + echo "toml_changed=$(git diff --name-only ${{ github.event.pull_request.base.ref }} ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" tests: needs: check_changes From dc813d3628e57fc72a592ce241940873f85ffabb Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:47:53 +0200 Subject: [PATCH 08/17] cleanup --- pyproject.toml | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7bc894ae..63f988fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,7 +186,6 @@ features = [ [tool.hatch.envs.default.scripts] # TODO: add scripts section based on Makefile # TODO: add bandit -# TODO: move scripts from linting and style here # Code Quality commands black-check = "black --check --diff ." black-fmt = "black ." @@ -215,15 +214,7 @@ non-spark-tests = "test -m \"not spark\"" # scripts.run = "- log-versions && pytest tests/ {env:HATCH_TEST_ARGS:} {args}" # run ="echo {args}" # run = "- pytest tests/ {env:HATCH_TEST_ARGS:} {args}" -# run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" -# cov-combine = "coverage combine" -# cov-report = "coverage report" # log-versions = "python --version && {env:HATCH_UV} pip freeze | grep pyspark" -# -# -# -# coverage = "- pytest tests/ {env:HATCH_TEST_ARGS:} {args} --cov=koheesio --cov-report=html --cov-report=term-missing --cov-fail-under=90" -# cov = "coverage" ### ~~~~~~~~~~~~~~~~~~~~~ ### From bb25c57a9adff22f6d37478f30d62297160da3e0 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:51:44 +0200 Subject: [PATCH 09/17] fix for workflow --- .github/workflows/test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ea76567d..86a53de3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,6 +36,8 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Check changes id: check run: | From 4f53a1668bef637ccad38f99d2c9024b0e36a812 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 10:54:40 +0200 Subject: [PATCH 10/17] fix for workflow --- .github/workflows/test.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 86a53de3..9cb0010c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,11 +38,15 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + ref: ${{ github.head_ref }} + - name: Fetch main branch + run: git fetch origin main:main - name: Check changes id: check run: | - echo "python_changed=$(git diff --name-only ${{ github.event.pull_request.base.ref }} ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" - echo "toml_changed=$(git diff --name-only ${{ github.event.pull_request.base.ref }} ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" + BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} + echo "python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" + echo "toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" tests: needs: check_changes From 9f2f92d5cd7bc4fd2bbd01d542f310cd832bcb99 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:01:58 +0200 Subject: [PATCH 11/17] fix for workflow --- .github/workflows/test.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9cb0010c..75741096 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,8 +45,10 @@ jobs: id: check run: | BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} - echo "python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$')" >> "$GITHUB_OUTPUT" - echo "toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$')" >> "$GITHUB_OUTPUT" + python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') + toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') + echo "::set-output name=python_changed::$python_changed" + echo "::set-output name=toml_changed::$toml_changed" tests: needs: check_changes From 742cfbc59d6235ac2e6790441cf2a8cbf0c0af7d Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:04:34 +0200 Subject: [PATCH 12/17] fix for workflow --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 75741096..40cbb767 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,8 +47,8 @@ jobs: BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') - echo "::set-output name=python_changed::$python_changed" - echo "::set-output name=toml_changed::$toml_changed" + echo "python_changed=$python_changed" >> $GITHUB_ENV + echo "toml_changed=$toml_changed" >> $GITHUB_ENV tests: needs: check_changes From 5a28bfbdb8ac53b05437dd31c651b9a656db08f0 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:05:58 +0200 Subject: [PATCH 13/17] fix for workflow --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 40cbb767..c23f8dd7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,8 +47,8 @@ jobs: BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') - echo "python_changed=$python_changed" >> $GITHUB_ENV - echo "toml_changed=$toml_changed" >> $GITHUB_ENV + echo "python_changed=$python_changed" >> $GITHUB_OUTPUT + echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT tests: needs: check_changes From f591b5c8fd22dd00b212c3cf930fb6e003badd82 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:07:06 +0200 Subject: [PATCH 14/17] fix for workflow --- .github/workflows/test.yml | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c23f8dd7..e5518885 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,11 +44,16 @@ jobs: - name: Check changes id: check run: | - BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} - python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') - toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') - echo "python_changed=$python_changed" >> $GITHUB_OUTPUT - echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT + # Set the base reference for the git diff + BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} + + # Check for changes in Python or TOML files + python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') + toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') + + # Write the changes to the GITHUB_OUTPUT environment file + echo "python_changed=$python_changed" >> $GITHUB_OUTPUT + echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT tests: needs: check_changes From b2c8aebb589b0d9abc07308cfcc2f5bc7a6eeb34 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:09:42 +0200 Subject: [PATCH 15/17] fix for workflow --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e5518885..62eb51ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,8 +52,8 @@ jobs: toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') # Write the changes to the GITHUB_OUTPUT environment file - echo "python_changed=$python_changed" >> $GITHUB_OUTPUT - echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT + echo "python_changed=$python_changed" >> "$GITHUB_OUTPUT" + echo "toml_changed=$toml_changed" >> "$GITHUB_OUTPUT" tests: needs: check_changes From c8643f7d64da7fa64fc8bd869dc9dc94ad223c43 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:19:17 +0200 Subject: [PATCH 16/17] fix for workflow --- .github/workflows/test.yml | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 62eb51ac..5aae3619 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -47,17 +47,20 @@ jobs: # Set the base reference for the git diff BASE_REF=${{ github.event.pull_request.base.ref || 'main' }} - # Check for changes in Python or TOML files - python_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.py$') - toml_changed=$(git diff --name-only $BASE_REF ${{ github.event.after }} | grep '\.toml$') + # Check for changes in this PR / commit + git_diff_output=$(git diff --name-only $BASE_REF ${{ github.event.after }}) + + # Count the number of changes to Python and TOML files + python_changed=$(echo "$git_diff_output" | grep '\.py$' | wc -l) + toml_changed=$(echo "$git_diff_output" | grep '\.toml$' | wc -l) # Write the changes to the GITHUB_OUTPUT environment file - echo "python_changed=$python_changed" >> "$GITHUB_OUTPUT" - echo "toml_changed=$toml_changed" >> "$GITHUB_OUTPUT" + echo "python_changed=$python_changed" >> $GITHUB_OUTPUT + echo "toml_changed=$toml_changed" >> $GITHUB_OUTPUT tests: needs: check_changes - if: needs.check_changes.outputs.python_changed != '' || needs.check_changes.outputs.toml_changed != '' || github.event_name == 'workflow_dispatch' + if: needs.check_changes.outputs.python_changed > 0 || needs.check_changes.outputs.toml_changed > 0 || github.event_name == 'workflow_dispatch' name: Python ${{ matrix.python-version }} with PySpark ${{ matrix.pyspark-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} runs-on: ${{ matrix.os }} From 303ea22873c1b598783c3c6a628772c11459ebd5 Mon Sep 17 00:00:00 2001 From: Danny Meijer <10511979+dannymeijer@users.noreply.github.com> Date: Fri, 7 Jun 2024 11:40:09 +0200 Subject: [PATCH 17/17] fix failing unit test --- .../spark/transformations/test_sql_transform.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/spark/transformations/test_sql_transform.py b/tests/spark/transformations/test_sql_transform.py index 0ffe73f4..9834d9c3 100644 --- a/tests/spark/transformations/test_sql_transform.py +++ b/tests/spark/transformations/test_sql_transform.py @@ -1,3 +1,4 @@ +from pathlib import Path from textwrap import dedent import pytest @@ -5,13 +6,16 @@ from koheesio.logger import LoggingFactory from koheesio.spark.transformations.sql_transform import SqlTransform -from tests.conftest import TEST_DATA_PATH - pytestmark = pytest.mark.spark log = LoggingFactory.get_logger(name="test_sql_transform") +@pytest.fixture +def test_data_path(data_path) -> Path: + return Path(data_path) / "transformations" + + @pytest.mark.parametrize( "input_values,expected", [ @@ -49,7 +53,7 @@ # input values dict( table_name="dummy_table", - sql_path=TEST_DATA_PATH / "transformations" / "dummy.sql", + sql_path="dummy.sql", ), # expected output {"id": 0, "incremented_id": 1}, @@ -59,14 +63,16 @@ # input values dict( table_name="dummy_table", - sql_path=str((TEST_DATA_PATH / "transformations" / "dummy.sql").as_posix()), + sql_path="dummy.sql", ), # expected output {"id": 0, "incremented_id": 1}, ), ], ) -def test_sql_transform(input_values, expected, dummy_df): +def test_sql_transform(input_values, expected, dummy_df, test_data_path): + if sql_path := input_values.get("sql_path"): + input_values["sql_path"] = str((test_data_path / sql_path).as_posix()) result = SqlTransform(**input_values).transform(dummy_df) actual = result.head().asDict()