diff --git a/daft/pyspark/__init__.py b/daft/pyspark/__init__.py new file mode 100644 index 0000000000..d8b557752e --- /dev/null +++ b/daft/pyspark/__init__.py @@ -0,0 +1,94 @@ +"""# PySpark. + +The `daft.pyspark` module provides a way to create a PySpark session that can be run locally or backed by a ray cluster. + +This serves as a way to run the daft query engine, but with a spark compatible API. + + +## Example + +```py +from daft.pyspark import SparkSession +from pyspark.sql.functions import col + +# create a local spark session +spark = SparkSession.builder.local().getOrCreate() + +# alternatively, connect to a ray cluster +spark = SparkSession.builder.remote("ray://:10001").getOrCreate() + +# use spark as you would with the native spark library, but with a daft backend! + +spark.createDataFrame([{"hello": "world"}]).select(col("hello")).show() + +# stop the spark session +spark.stop() +``` +""" + +from daft.daft import connect_start +from pyspark.sql import SparkSession as PySparkSession + + +class Builder: + def __init__(self): + self._builder = PySparkSession.builder + + def local(self): + self._connection = connect_start() + url = f"sc://0.0.0.0:{self._connection.port()}" + self._builder = PySparkSession.builder.remote(url) + return self + + def remote(self, url): + if url.startswith("ray://"): + import daft + + if url.startswith("ray://localhost"): + daft.context.set_runner_ray(noop_if_initialized=True) + else: + daft.context.set_runner_ray(address=url, noop_if_initialized=True) + self._connection = connect_start() + url = f"sc://0.0.0.0:{self._connection.port()}" + self._builder = PySparkSession.builder.remote(url) + return self + else: + self._builder = PySparkSession.builder.remote(url) + return self + + def getOrCreate(self): + return SparkSession(self._builder.getOrCreate(), self._connection) + + def __getattr__(self, name): + attr = getattr(self._builder, name) + if callable(attr): + + def wrapped(*args, **kwargs): + result = attr(*args, **kwargs) + # If result is the original builder, return self instead + return self if result == self._builder else result + + return wrapped + return attr + + __doc__ = property(lambda self: self._spark_session.__doc__) + + +class SparkSession: + builder = Builder() + + def __init__(self, spark_session, connection=None): + self._spark_session = spark_session + self._connection = connection + + def __repr__(self): + return self._spark_session.__repr__() + + def __getattr__(self, name): + return getattr(self._spark_session, name) + + def stop(self): + self._spark_session.stop() + self._connection.shutdown() + + __doc__ = property(lambda self: self._spark_session.__doc__) diff --git a/tests/connect/conftest.py b/tests/connect/conftest.py index ab7e36777b..0c24ec6f3d 100644 --- a/tests/connect/conftest.py +++ b/tests/connect/conftest.py @@ -1,28 +1,39 @@ from __future__ import annotations import pytest -from pyspark.sql import SparkSession + + +@pytest.fixture(params=["local_spark", "ray_spark"], scope="session") +def spark_session(request): + return request.getfixturevalue(request.param) @pytest.fixture(scope="session") -def spark_session(): +def local_spark(): """Fixture to create and clean up a Spark session. This fixture is available to all test files and creates a single Spark session for the entire test suite run. """ - from daft.daft import connect_start + from daft.pyspark import SparkSession + + # Initialize Spark Connect session + session = SparkSession.builder.appName("DaftConfigTest").local().getOrCreate() + + yield session + - # Start Daft Connect server - server = connect_start() +@pytest.fixture(scope="session") +def ray_spark(): + """Fixture to create and clean up a Spark session. - url = f"sc://localhost:{server.port()}" + This fixture is available to all test files and creates a single + Spark session for the entire test suite run. + """ + from daft.pyspark import SparkSession # Initialize Spark Connect session - session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate() - yield session + session = SparkSession.builder.appName("DaftConfigTest").remote("ray://localhost:10001").getOrCreate() - # Cleanup - server.shutdown() - session.stop() + yield session