[SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace
### What changes were proposed in this pull request?

This PR adds a new configuration, `spark.sql.execution.pyspark.udf.hideTraceback.enabled`. When it is enabled and an exception raised by a Python UDF is handled, only the exception class and message are included; the stack trace is omitted. The configuration is turned off by default.
This PR also adds a new optional parameter `hide_traceback` to `handle_worker_exception` so callers can override the configuration.
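
For illustration only (not part of the change itself), a minimal sketch of how the new configuration and the override parameter could be exercised; the `spark` session object is assumed to already exist:

```py
import io

from pyspark.util import handle_worker_exception

# Enable the new SQL configuration on an existing session (assumed to be bound to `spark`).
spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", "true")

# The worker-side helper also accepts an explicit override, independent of the configuration.
try:
    raise ValueError("boom")
except ValueError as e:
    buf = io.BytesIO()
    handle_worker_exception(e, buf, hide_traceback=True)
    # buf now holds the exception class and message only; the traceback is omitted.
```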

Suggested review order:
1. `python/pyspark/util.py`: logic changes
2. `python/pyspark/tests/test_util.py`: unit tests
3. other files: adding new configuration

### Why are the changes needed?

This allows library-provided UDFs to surface only the relevant error message, without an unnecessary stack trace.

### Does this PR introduce _any_ user-facing change?

If the configuration is turned off (the default), there is no user-facing change.
Otherwise, the stack trace is not included in the error message when an exception from a Python UDF is handled.

<details>
<summary>Example that illustrates the difference</summary>

```py
from pyspark.errors.exceptions.base import PySparkRuntimeError
from pyspark.sql.types import IntegerType, StructField, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
from pyspark.sql.functions import udtf

@udtf()
class PythonUDTF:
    @staticmethod
    def analyze(x: AnalyzeArgument) -> AnalyzeResult:
        raise PySparkRuntimeError("[XXX] My PySpark runtime error.")

    def eval(self, x: int):
        yield (x,)

spark.udtf.register("my_udtf", PythonUDTF)
spark.sql("select * from my_udtf(1)").show()
```

With the configuration turned off, the last line gives:
```
...
pyspark.errors.exceptions.captured.AnalysisException: [TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the Python user defined table function: Traceback (most recent call last):
  File "<stdin>", line 7, in analyze
pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark runtime error. SQLSTATE: 38000; line 1 pos 14
```

With the configuration turned on, the last line gives:
```
...
pyspark.errors.exceptions.captured.AnalysisException: [TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON] Failed to analyze the Python user defined table function: pyspark.errors.exceptions.base.PySparkRuntimeError: [XXX] My PySpark runtime error. SQLSTATE: 38000; line 1 pos 14
```

</details>

### How was this patch tested?

Added unit tests in `python/pyspark/tests/test_util.py`, covering cases with the traceback shown and hidden, driven both by the `SPARK_HIDE_TRACEBACK` environment variable and by the explicit `hide_traceback` parameter.
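
For reference, one way to run just this test module locally (assuming a working PySpark development environment) is:

```
python/run-tests --testnames pyspark.tests.test_util
```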

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49535 from wengh/spark-50858-hide-udf-stack-trace.

Authored-by: Haoyu Weng <[email protected]>
Signed-off-by: Allison Wang <[email protected]>
(cherry picked from commit d259132)
Signed-off-by: Allison Wang <[email protected]>
wengh authored and allisonwang-db committed Jan 28, 2025
1 parent 83d5d44 commit 40f6b3f
Showing 11 changed files with 90 additions and 6 deletions.
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
  private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
  protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
  protected val hideTraceback: Boolean = false
  protected val simplifiedTraceback: Boolean = false

  // All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@
    if (reuseWorker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }
    if (hideTraceback) {
      envVars.put("SPARK_HIDE_TRACEBACK", "1")
    }
    if (simplifiedTraceback) {
      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
    }
41 changes: 41 additions & 0 deletions python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
#
import os
import unittest
from unittest.mock import patch

from py4j.protocol import Py4JJavaError

@@ -125,6 +126,46 @@ def test_parse_memory(self):
            _parse_memory("2gs")


class HandleWorkerExceptionTests(unittest.TestCase):
    exception_bytes = b"ValueError: test_message"
    traceback_bytes = b"Traceback (most recent call last):"

    def run_handle_worker_exception(self, hide_traceback=None):
        import io
        from pyspark.util import handle_worker_exception

        try:
            raise ValueError("test_message")
        except Exception as e:
            with io.BytesIO() as stream:
                handle_worker_exception(e, stream, hide_traceback)
                return stream.getvalue()

    @patch.dict(os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""})
    def test_env_full(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertIn(self.traceback_bytes, result)

    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
    def test_env_hide_traceback(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertNotIn(self.traceback_bytes, result)

    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
    def test_full(self):
        # Should ignore the environment variable because hide_traceback is explicitly set.
        result = self.run_handle_worker_exception(False)
        self.assertIn(self.exception_bytes, result)
        self.assertIn(self.traceback_bytes, result)

    def test_hide_traceback(self):
        result = self.run_handle_worker_exception(True)
        self.assertIn(self.exception_bytes, result)
        self.assertNotIn(self.traceback_bytes, result)


if __name__ == "__main__":
    from pyspark.tests.test_util import *  # noqa: F401

30 changes: 24 additions & 6 deletions python/pyspark/util.py
@@ -462,22 +462,40 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
    return f  # type: ignore[return-value]


def handle_worker_exception(e: BaseException, outfile: IO) -> None:
def handle_worker_exception(
    e: BaseException, outfile: IO, hide_traceback: Optional[bool] = None
) -> None:
    """
    Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
    and exception traceback info to outfile. JVM could then read from the outfile and perform
    exception handling there.

    Parameters
    ----------
    e : BaseException
        Exception handled
    outfile : IO
        IO object to write the exception info
    hide_traceback : bool, optional
        Whether to hide the traceback in the output.
        By default, hides the traceback if environment variable SPARK_HIDE_TRACEBACK is set.
    """
    try:
        exc_info = None

    if hide_traceback is None:
        hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))

    def format_exception() -> str:
        if hide_traceback:
            return "".join(traceback.format_exception_only(type(e), e))
        if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
            tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
            if tb is not None:
                e.__cause__ = None
                exc_info = "".join(traceback.format_exception(type(e), e, tb))
        if exc_info is None:
            exc_info = traceback.format_exc()
                return "".join(traceback.format_exception(type(e), e, tb))
        return traceback.format_exc()

    try:
        exc_info = format_exception()
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(exc_info.encode("utf-8"), outfile)
    except IOError:
@@ -3476,6 +3476,15 @@ object SQLConf {
      .checkValues(Set("legacy", "row", "dict"))
      .createWithDefaultString("legacy")

  val PYSPARK_HIDE_TRACEBACK =
    buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
      .doc(
        "When true, only show the message of the exception from Python UDFs, " +
          "hiding the stack trace. If this is enabled, simplifiedTraceback has no effect.")
      .version("4.0.0")
      .booleanConf
      .createWithDefault(false)

  val PYSPARK_SIMPLIFIED_TRACEBACK =
    buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
      .doc(
@@ -6287,6 +6296,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

  def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)

  def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)

  def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

  def pandasGroupedMapAssignColumnsByName: Boolean =
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
  override protected lazy val timeZoneId: String = _timeZoneId
  override val errorOnDuplicatedFieldNames: Boolean = true

  override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
  override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback

  override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(

  override val errorOnDuplicatedFieldNames: Boolean = true

  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

  // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(

  override val errorOnDuplicatedFieldNames: Boolean = true

  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(

  override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

  protected def newWriter(
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

    override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

    override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
    override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
  }
}
@@ -52,6 +52,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
    val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
    val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
    val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

@@ -68,6 +69,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
    if (reuseWorker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }
    if (hideTraceback) {
      envVars.put("SPARK_HIDE_TRACEBACK", "1")
    }
    if (simplifiedTraceback) {
      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
    }
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
    SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
      funcs.head._1.funcs.head.pythonExec)

  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

  override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
