Skip to content

Commit

Permalink
feat(connect): daft.pyspark module (#3861)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Feb 26, 2025
1 parent c3591a2 commit d468589
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 11 deletions.
94 changes: 94 additions & 0 deletions daft/pyspark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""# PySpark.
The `daft.pyspark` module provides a way to create a PySpark session that can be run locally or backed by a ray cluster.
This serves as a way to run the daft query engine, but with a Spark-compatible API.
## Example
```py
from daft.pyspark import SparkSession
from pyspark.sql.functions import col
# create a local spark session
spark = SparkSession.builder.local().getOrCreate()
# alternatively, connect to a ray cluster
spark = SparkSession.builder.remote("ray://<ray-ip>:10001").getOrCreate()
# use spark as you would with the native spark library, but with a daft backend!
spark.createDataFrame([{"hello": "world"}]).select(col("hello")).show()
# stop the spark session
spark.stop()
```
"""

from daft.daft import connect_start
from pyspark.sql import SparkSession as PySparkSession


class Builder:
    """Fluent builder for a daft-backed PySpark session.

    Wraps ``pyspark.sql.SparkSession.builder`` and points it at a daft
    Spark Connect server (started locally or against a ray cluster)
    before ``getOrCreate`` is called.
    """

    def __init__(self):
        self._builder = PySparkSession.builder
        # BUG FIX: initialize the connection handle. remote() with a plain
        # spark-connect URL never starts a daft server, and previously
        # getOrCreate() then crashed with AttributeError on _connection.
        self._connection = None

    def local(self):
        """Start an in-process daft connect server and target it."""
        self._connection = connect_start()
        url = f"sc://0.0.0.0:{self._connection.port()}"
        self._builder = PySparkSession.builder.remote(url)
        return self

    def remote(self, url):
        """Target *url*.

        ``ray://...`` URLs switch daft to the ray runner and start a local
        daft connect server; any other URL is handed to PySpark unchanged.
        """
        if url.startswith("ray://"):
            import daft

            if url.startswith("ray://localhost"):
                daft.context.set_runner_ray(noop_if_initialized=True)
            else:
                daft.context.set_runner_ray(address=url, noop_if_initialized=True)
            self._connection = connect_start()
            # Replace the ray URL with the freshly started connect server's.
            url = f"sc://0.0.0.0:{self._connection.port()}"
        self._builder = PySparkSession.builder.remote(url)
        return self

    def getOrCreate(self):
        """Build the underlying PySpark session and wrap it together with
        the daft connect server handle (None if none was started here)."""
        return SparkSession(self._builder.getOrCreate(), self._connection)

    def __getattr__(self, name):
        # Delegate unknown attributes (appName, config, ...) to the wrapped
        # PySpark builder while keeping the fluent chain on this Builder.
        attr = getattr(self._builder, name)
        if callable(attr):

            def wrapped(*args, **kwargs):
                result = attr(*args, **kwargs)
                # PySpark builder methods return the builder itself; return
                # this wrapper instead so chained calls stay on it.
                # (identity check — equality is not meaningful here)
                return self if result is self._builder else result

            return wrapped
        return attr

    # BUG FIX: the original read self._spark_session, an attribute Builder
    # never defines, so accessing __doc__ always raised AttributeError.
    __doc__ = property(lambda self: self._builder.__doc__)


class SparkSession:
    """A PySpark session backed by the daft query engine.

    Thin wrapper around a ``pyspark.sql.SparkSession`` that also owns the
    daft connect server handle so ``stop()`` can shut it down.
    """

    # NOTE(review): shared, mutable class-level builder (mirrors the
    # pyspark ``SparkSession.builder`` entry point) — all callers reuse the
    # same Builder instance; confirm this is intentional.
    builder = Builder()

    def __init__(self, spark_session, connection=None):
        # connection is the daft connect server handle, or None when the
        # session targets an externally managed spark-connect server.
        self._spark_session = spark_session
        self._connection = connection

    def __repr__(self):
        return self._spark_session.__repr__()

    def __getattr__(self, name):
        # Delegate everything else to the wrapped pyspark session.
        return getattr(self._spark_session, name)

    def stop(self):
        """Stop the session and shut down the daft connect server, if any."""
        self._spark_session.stop()
        # BUG FIX: _connection defaults to None; only shut down a server
        # that was actually started by this session's builder.
        if self._connection is not None:
            self._connection.shutdown()

    __doc__ = property(lambda self: self._spark_session.__doc__)
33 changes: 22 additions & 11 deletions tests/connect/conftest.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,39 @@
from __future__ import annotations

import pytest
from pyspark.sql import SparkSession


@pytest.fixture(params=["local_spark", "ray_spark"], scope="session")
def spark_session(request):
    """Parametrized session fixture: resolves to either the local or the
    ray-backed Spark session so every test runs against both backends."""
    return request.getfixturevalue(request.param)


@pytest.fixture(scope="session")
def local_spark():
    """Fixture to create and clean up a Spark session.

    Creates a single Spark session, backed by a local in-process daft
    connect server, for the entire test suite run.
    """
    from daft.pyspark import SparkSession

    # Initialize Spark Connect session against a local daft server.
    session = SparkSession.builder.appName("DaftConfigTest").local().getOrCreate()

    # NOTE(review): no explicit session.stop() here — confirm relying on
    # process-exit cleanup for the local connect server is acceptable.
    yield session


# Start Daft Connect server
server = connect_start()
@pytest.fixture(scope="session")
def ray_spark():
"""Fixture to create and clean up a Spark session.
url = f"sc://localhost:{server.port()}"
This fixture is available to all test files and creates a single
Spark session for the entire test suite run.
"""
from daft.pyspark import SparkSession

# Initialize Spark Connect session
session = SparkSession.builder.appName("DaftConfigTest").remote(url).getOrCreate()

yield session
session = SparkSession.builder.appName("DaftConfigTest").remote("ray://localhost:10001").getOrCreate()

# Cleanup
server.shutdown()
session.stop()
yield session

0 comments on commit d468589

Please sign in to comment.