Skip to content

Commit

Permalink
update serialization exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Mar 5, 2025
1 parent b66b680 commit 5d7d0f0
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
2 changes: 1 addition & 1 deletion runhouse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import runhouse.resources.images.builtin_images as images

from runhouse.exceptions import InsufficientDiskError
from runhouse.exceptions import InsufficientDiskError, SerializationError
from runhouse.resources.asgi import Asgi, asgi
from runhouse.resources.folders import Folder, folder, GCSFolder, S3Folder
from runhouse.resources.functions.function import Function
Expand Down
18 changes: 18 additions & 0 deletions runhouse/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,21 @@ def __init__(
)
msg = f"{msg}. To resolve it, teardown the cluster and re-launch it with larger disk size."
super().__init__(msg)


class SerializationError(Exception):
"""Raised when we have serialization error.
Args:
error_msg: The error message to print.
"""

def __init__(
self,
error_msg: str = None,
) -> None:
self.error_msg = error_msg
self.default_error_msg = "Got a serialization error."
msg = self.error_msg if self.error_msg else self.default_error_msg
msg = f"{msg}. Make sure that the remote and local versions of python and all installed packages are as expected.\n Please Check logs for more information."
super().__init__(msg)
36 changes: 36 additions & 0 deletions runhouse/resources/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,43 @@
import builtins
import os
import sys
import traceback
from pathlib import Path


def extract_error_from_ray_result(result_obj):
error = result_obj.error
error_msg = error.__str__()
exception = error_msg.split("\n")[-1].strip().split(": ")
exception_type = exception[0]
exception_msg = exception[1] if len(exception) > 1 else ""

exception_class = Exception

# Try to find more the exact exception that is should be raised
if hasattr(builtins, exception_type):
exception_class = getattr(builtins, exception_type)
else:
# Try to find the exception class in common modules
for module_name in ["exceptions", "os", "io", "socket", "ray.exceptions"]:
try:
module = sys.modules.get(module_name) or __import__(module_name)
if hasattr(module, exception_type):
exception_class = getattr(module, exception_type)
break
# ImportError, AttributeError are part of the builtin methods.
except (ImportError, AttributeError):
continue

# Create the exception instance with the original message
exception_instance = exception_class(exception_msg)

# Optionally add the original traceback as a note (Python 3.11+)
if hasattr(exception_instance, "__notes__"):
exception_instance.__notes__ = traceback.format_exception(error)
raise exception_instance


def subprocess_ray_fn_call_helper(pointers, args, kwargs, conn, ray_opts={}):
def write_stdout(msg):
conn.send((msg, "stdout"))
Expand Down Expand Up @@ -37,6 +72,7 @@ def write_stderr(msg):
)
try:
res = orig_fn(*args, **kwargs)
res = extract_error_from_ray_result(res) if hasattr(res, "error") else res
return res
finally:
ray.shutdown()
Expand Down
38 changes: 31 additions & 7 deletions runhouse/servers/http/http_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import codecs
import builtins
import json
import re
import shutil
Expand All @@ -10,7 +10,7 @@

from pydantic import BaseModel, field_validator

from runhouse.exceptions import InsufficientDiskError
from runhouse.exceptions import InsufficientDiskError, SerializationError

from runhouse.logger import get_logger

Expand Down Expand Up @@ -170,15 +170,25 @@ def convert_path_to_string(cls, v):


def pickle_b64(picklable):
import codecs

import cloudpickle

return codecs.encode(cloudpickle.dumps(picklable), "base64").decode()
try:
return codecs.encode(cloudpickle.dumps(picklable), "base64").decode()
except Exception as e:
raise SerializationError(error_msg=e.__str__())


def b64_unpickle(b64_pickled):
import codecs

import cloudpickle

return cloudpickle.loads(codecs.decode(b64_pickled.encode(), "base64"))
try:
return cloudpickle.loads(codecs.decode(b64_pickled.encode(), "base64"))
except Exception as e:
raise SerializationError(error_msg=e.__str__())


def deserialize_data(data: Any, serialization: Optional[str]):
Expand Down Expand Up @@ -339,26 +349,40 @@ def handle_response(
except Exception as e:
logger.error(
f"{system_color}{err_str}: Failed to unpickle exception. Please check the logs for more "
f"information.{reset_color}"
f"information.\n Make sure that the remote and local versions of python and all installed packages are as expected.{reset_color}"
)
if fn_exception_as_str:
logger.error(
f"{system_color}{err_str} Exception as string: {fn_exception_as_str}{reset_color}"
)
logger.error(f"{system_color}Traceback: {fn_traceback}{reset_color}")
raise e
if isinstance(e, SerializationError):
raise e

raise SerializationError(error_msg=err_str)

is_builtins_exception, fn_exception_msg = None, ""

if not (
isinstance(fn_exception, StopIteration)
or isinstance(fn_exception, GeneratorExit)
or isinstance(fn_exception, StopAsyncIteration)
):
logger.error(f"{system_color}{err_str}: {fn_exception}{reset_color}")
is_builtins_exception = fn_exception.get("module") == "builtins"
fn_exception_msg = (
fn_exception.get("args")[0] if is_builtins_exception else fn_exception
)
logger.error(f"{system_color}{err_str}: {fn_exception_msg}{reset_color}")
logger.error(f"{system_color}Traceback: {fn_traceback}{reset_color}")

# Errno 28 means "No space left on device"
if isinstance(fn_exception, OSError) and fn_exception.errno == 28:
raise InsufficientDiskError()

if is_builtins_exception:
exception_class = getattr(builtins, fn_exception.get("class_name"), None)
if exception_class:
raise exception_class(fn_exception_msg)
raise fn_exception
elif output_type == OutputType.STDOUT:
res = response_data["data"]
Expand Down

0 comments on commit 5d7d0f0

Please sign in to comment.