feat(connect): daft.pyspark module (#3861)

Commit d468589 (parent c3591a2). Showing 2 changed files with 116 additions and 11 deletions.
File 1: daft/pyspark.py (new file, +94 lines)
"""# PySpark. | ||
The `daft.pyspark` module provides a way to create a PySpark session that can be run locally or backed by a ray cluster. | ||
This serves as a way to run the daft query engine, but with a spark compatible API. | ||
## Example | ||
```py | ||
from daft.pyspark import SparkSession | ||
from pyspark.sql.functions import col | ||
# create a local spark session | ||
spark = SparkSession.builder.local().getOrCreate() | ||
# alternatively, connect to a ray cluster | ||
spark = SparkSession.builder.remote("ray://<ray-ip>:10001").getOrCreate() | ||
# use spark as you would with the native spark library, but with a daft backend! | ||
spark.createDataFrame([{"hello": "world"}]).select(col("hello")).show() | ||
# stop the spark session | ||
spark.stop() | ||
``` | ||
""" | ||
|
||
from daft.daft import connect_start | ||
from pyspark.sql import SparkSession as PySparkSession | ||
|
||
|
||
class Builder: | ||
def __init__(self): | ||
self._builder = PySparkSession.builder | ||
|
||
def local(self): | ||
self._connection = connect_start() | ||
url = f"sc://0.0.0.0:{self._connection.port()}" | ||
self._builder = PySparkSession.builder.remote(url) | ||
return self | ||
|
||
def remote(self, url): | ||
if url.startswith("ray://"): | ||
import daft | ||
|
||
if url.startswith("ray://localhost"): | ||
daft.context.set_runner_ray(noop_if_initialized=True) | ||
else: | ||
daft.context.set_runner_ray(address=url, noop_if_initialized=True) | ||
self._connection = connect_start() | ||
url = f"sc://0.0.0.0:{self._connection.port()}" | ||
self._builder = PySparkSession.builder.remote(url) | ||
return self | ||
else: | ||
self._builder = PySparkSession.builder.remote(url) | ||
return self | ||
|
||
def getOrCreate(self): | ||
return SparkSession(self._builder.getOrCreate(), self._connection) | ||
|
||
def __getattr__(self, name): | ||
attr = getattr(self._builder, name) | ||
if callable(attr): | ||
|
||
def wrapped(*args, **kwargs): | ||
result = attr(*args, **kwargs) | ||
# If result is the original builder, return self instead | ||
return self if result == self._builder else result | ||
|
||
return wrapped | ||
return attr | ||
|
||
__doc__ = property(lambda self: self._spark_session.__doc__) | ||
|
||
|
||
class SparkSession: | ||
builder = Builder() | ||
|
||
def __init__(self, spark_session, connection=None): | ||
self._spark_session = spark_session | ||
self._connection = connection | ||
|
||
def __repr__(self): | ||
return self._spark_session.__repr__() | ||
|
||
def __getattr__(self, name): | ||
return getattr(self._spark_session, name) | ||
|
||
def stop(self): | ||
self._spark_session.stop() | ||
self._connection.shutdown() | ||
|
||
__doc__ = property(lambda self: self._spark_session.__doc__) |
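Because `Builder.__getattr__` forwards unknown methods to the underlying PySpark builder and returns `self` whenever the call would have returned that builder, standard builder options such as `appName` can be chained with `local()` or `remote()` (the test fixtures below rely on exactly this). A minimal usage sketch, assuming the local backend; the app name is illustrative:

```py
from daft.pyspark import SparkSession
from pyspark.sql.functions import col

# appName() resolves via Builder.__getattr__ and chains back into local()
spark = SparkSession.builder.appName("demo").local().getOrCreate()

# regular PySpark DataFrame API, executed by Daft behind Spark Connect
spark.createDataFrame([{"hello": "world"}]).select(col("hello")).show()

# stop() also shuts down the embedded Connect server
spark.stop()
```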
File 2: Spark session test fixtures (pytest conftest; 22 additions, 11 deletions)

@@ -1,28 +1,39 @@
 from __future__ import annotations
 
 import pytest
 from pyspark.sql import SparkSession
 
 
+@pytest.fixture(params=["local_spark", "ray_spark"], scope="session")
+def spark_session(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="session")
-def spark_session():
+def local_spark():
     """Fixture to create and clean up a Spark session.
 
     This fixture is available to all test files and creates a single
     Spark session for the entire test suite run.
     """
-    from daft.daft import connect_start
+    from daft.pyspark import SparkSession
 
-    # Start Daft Connect server
-    server = connect_start()
-    url = f"sc://localhost:{server.port()}"
+    # Initialize Spark Connect session
+    session = SparkSession.builder.appName("DaftConfigTest").local().getOrCreate()
 
-    # Initialize Spark Connect session
-    session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate()
+    yield session
 
-    yield session
 
-    # Cleanup
-    server.shutdown()
-    session.stop()
+@pytest.fixture(scope="session")
+def ray_spark():
+    """Fixture to create and clean up a Spark session.
+
+    This fixture is available to all test files and creates a single
+    Spark session for the entire test suite run.
+    """
+    from daft.pyspark import SparkSession
+
+    # Initialize Spark Connect session
+    session = SparkSession.builder.appName("DaftConfigTest").remote("ray://localhost:10001").getOrCreate()
+
+    yield session
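Since `spark_session` is parametrized over `local_spark` and `ray_spark`, every test that requests it runs once per backend. A hypothetical test sketch (the test name and assertion are illustrative, not part of this commit):

```py
from pyspark.sql.functions import col


def test_select_roundtrip(spark_session):
    # runs twice: once against the local backend, once against Ray
    df = spark_session.createDataFrame([{"hello": "world"}])
    rows = df.select(col("hello")).collect()
    assert [r["hello"] for r in rows] == ["world"]
```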