Skip to content

Commit

Permalink
fix: make baml_py work with playwright/inspect (#1214)
Browse files Browse the repository at this point in the history
When using `baml_py.Image` in a Pydantic model with playwright-python
(the headless browser stack), the playwright library freezes up (see
[user
report](https://discord.com/channels/1119368998161752075/1309277325703250000)).
Specifically this code suffices to repro the issue:

```
from playwright.sync_api import sync_playwright
import baml_py
from pydantic import BaseModel

class Foo(BaseModel):
    screenshot: baml_py.Image

print('This happens')
sync_playwright().start()
print('This never happens')
```

The reason turns out to be that when the playwright context manager is
entered, deep in the callstack, the playwright connection grabs the
stack context using `inspect.stack()`. Because of how
`__get_pydantic_core_schema__` was implemented for `baml_py.Image` and
`baml_py.Audio`, declaring a pydantic model that relied on either of
these types would cause `inspect.stack()` to crash:

```
============================================================================= FAILURES ==============================================================================
___________________________________________________________________________ test_inspect ____________________________________________________________________________

    def test_inspect():
        class LoremIpsum(pydantic.BaseModel):
            my_image: baml_py.Image

>       inspect.stack()

tests/test_pydantic.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:1673: in stack
    return getouterframes(sys._getframe(1), context)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:1650: in getouterframes
    frameinfo = (frame,) + getframeinfo(frame, context)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:1624: in getframeinfo
    lines, lnum = findsource(frame)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:952: in findsource
    module = getmodule(object, file)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:875: in getmodule
    f = getabsfile(module)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:844: in getabsfile
    _filename = getsourcefile(object) or getfile(object)
../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:817: in getsourcefile
    filename = getfile(object)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

object = <module 'rust:media_repr' from '/Users/sam/baml/integ-tests/python'>

    def getfile(object):
        """Work out which source or compiled file an object was defined in."""
        if ismodule(object):
            if getattr(object, '__file__', None):
                return object.__file__
>           raise TypeError('{!r} is a built-in module'.format(object))
E           TypeError: <module 'rust:media_repr' from '/Users/sam/baml/integ-tests/python'> is a built-in module

../../../.local/share/mise/installs/python/3.10.14/lib/python3.10/inspect.py:778: TypeError
====================================================================== short test summary info ======================================================================
FAILED tests/test_pydantic.py::test_inspect - TypeError: <module 'rust:media_repr' from '/Users/sam/baml/integ-tests/python'> is a built-in module
```

The fix turns out to be very simple: when we evaluate Python code in
`__get_pydantic_core_schema__`, we just need to actually synthesize a
file and module name for the evaluated code, instead of passing in an
empty string (which is what we do today).

(Also, since we added pickle support for image/audio during the
investigation, keep it.)
<!-- ELLIPSIS_HIDDEN -->

----

> [!IMPORTANT]
> Fixes `inspect.stack()` crash and adds pickle support for
`baml_py.Image` and `baml_py.Audio`, with updated tests and CI workflow.
> 
>   - **Behavior**:
> - Fixes `inspect.stack()` crash by providing file and module name in
`__get_pydantic_core_schema__` for `baml_py.Image` and `baml_py.Audio`.
> - Adds pickle support for `baml_py.Image` and `baml_py.Audio` with
`py_new` and `__getnewargs__` methods.
>   - **Testing**:
> - Adds `test_inspect` and `test_pickle` in `test_python.py` to verify
compatibility and pickle functionality.
>     - Introduces `run_tests.sh` for running Python integration tests.
>   - **CI/CD**:
> - Updates `primary.yml` to include a scheduled job and integration
tests job.
>     - Adds `poetry` version to `.mise.toml`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for bb55aa2. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
sxlijin authored Dec 5, 2024
1 parent e107900 commit 6741999
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 14 deletions.
20 changes: 20 additions & 0 deletions .github/workflows/primary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ on:
- ".github/workflows/primary.yml"
branches:
- canary
# need to run this periodically on the default branch to populate the build cache
schedule:
# daily at 2am PST
- cron: 0 10 * * *
merge_group:
types: [checks_requested]
workflow_dispatch: {}
Expand Down Expand Up @@ -106,3 +110,19 @@ jobs:
- name: Build rust for wasm32
run: cargo build --target=wasm32-unknown-unknown
working-directory: engine/baml-schema-wasm
integ-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: jdx/mise-action@v2
- uses: dtolnay/rust-toolchain@stable
with:
toolchain: stable
- uses: Swatinem/rust-cache@v2
with:
workspaces: engine
- name: run python tests
run: |
cd integ-tests/python
poetry install
./run_tests.sh
1 change: 1 addition & 0 deletions .mise.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
node = "20.14"
ruby = "3.1"
pnpm = "9.9"
poetry = "1.8.4"
6 changes: 5 additions & 1 deletion engine/language_client_python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ regex.workspace = true
serde.workspace = true
serde_json.workspace = true
tokio = { version = "1", features = ["full"] }
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter","valuable"] }
tracing-subscriber = { version = "0.3.18", features = [
"json",
"env-filter",
"valuable",
] }

[build-dependencies]
pyo3-build-config = "0.21.2"
7 changes: 2 additions & 5 deletions engine/language_client_python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ fn invoke_runtime_cli(py: Python) -> PyResult<()> {
.map_err(errors::BamlError::from_anyhow)
}

pub(crate) const MODULE_NAME: &str = "baml_py.baml_py";

#[pymodule]
fn baml_py(m: Bound<'_, PyModule>) -> PyResult<()> {
let use_json = match std::env::var("BAML_LOG_JSON") {
Expand Down Expand Up @@ -74,11 +76,6 @@ fn baml_py(m: Bound<'_, PyModule>) -> PyResult<()> {

m.add_wrapped(wrap_pyfunction!(invoke_runtime_cli))?;

// m.add(
// "BamlValidationError",
// m.py().get_type_bound::<errors::BamlValidationError>(),
// )?;
// m.add_class::<errors::BamlValidationError>()?;
errors::errors(&m)?;

Ok(())
Expand Down
16 changes: 15 additions & 1 deletion engine/language_client_python/src/types/audio.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use baml_types::BamlMediaContent;
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::PyType;
use pyo3::types::{PyTuple, PyType};
use pyo3::{Bound, PyAny, PyObject, Python};
use pythonize::{depythonize_bound, pythonize};

Expand Down Expand Up @@ -50,6 +50,20 @@ impl BamlAudioPy {
}
}

/// Defines the default constructor: https://pyo3.rs/v0.23.3/class#constructor
///
/// Used for `pickle.load`: https://docs.python.org/3/library/pickle.html#object.__getnewargs__
#[new]
pub fn py_new(data: PyObject, py: Python<'_>) -> PyResult<Self> {
Self::baml_deserialize(data, py)
}

/// Used for `pickle.dump`: https://docs.python.org/3/library/pickle.html#object.__getnewargs__
pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
let o = self.baml_serialize(py)?;
Ok(PyTuple::new_bound(py, vec![o]))
}

pub fn __repr__(&self) -> String {
match &self.inner.content {
BamlMediaContent::Url(url) => {
Expand Down
16 changes: 15 additions & 1 deletion engine/language_client_python/src/types/image.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::PyType;
use pyo3::types::{PyTuple, PyType};
use pyo3::{Bound, PyAny, PyObject, Python};
use pythonize::{depythonize_bound, pythonize};

Expand Down Expand Up @@ -49,6 +49,20 @@ impl BamlImagePy {
}
}

/// Defines the default constructor: https://pyo3.rs/v0.23.3/class#constructor
///
/// Used for `pickle.load`: https://docs.python.org/3/library/pickle.html#object.__getnewargs__
#[new]
pub fn py_new(data: PyObject, py: Python<'_>) -> PyResult<Self> {
Self::baml_deserialize(data, py)
}

/// Used for `pickle.dump`: https://docs.python.org/3/library/pickle.html#object.__getnewargs__
pub fn __getnewargs__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
let o = self.baml_serialize(py)?;
Ok(PyTuple::new_bound(py, vec![o]))
}

pub fn __repr__(&self) -> String {
match &self.inner.content {
baml_types::BamlMediaContent::Url(url) => {
Expand Down
10 changes: 5 additions & 5 deletions engine/language_client_python/src/types/lang_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[macro_export]
macro_rules! lang_wrapper {
($name:ident, $type:ty, clone_safe $(, $attr_name:ident : $attr_type:ty = $default:expr)*) => {
#[pyo3::prelude::pyclass]
#[pyo3::prelude::pyclass(module = "baml_py.baml_py")]
pub struct $name {
pub(crate) inner: std::sync::Arc<$type>,
$($attr_name: $attr_type),*
Expand All @@ -18,7 +18,7 @@ macro_rules! lang_wrapper {
};

($name:ident, $type:ty, thread_safe $(, $attr_name:ident : $attr_type:ty)*) => {
#[pyo3::prelude::pyclass]
#[pyo3::prelude::pyclass(module = "baml_py.baml_py")]
pub struct $name {
pub(crate) inner: std::sync::Arc<tokio::sync::Mutex<$type>>,
$($attr_name: $attr_type),*
Expand All @@ -35,7 +35,7 @@ macro_rules! lang_wrapper {
};

($name:ident, $type:ty, sync_thread_safe $(, $attr_name:ident : $attr_type:ty)*) => {
#[pyo3::prelude::pyclass]
#[pyo3::prelude::pyclass(module = "baml_py.baml_py")]
pub struct $name {
pub(crate) inner: std::sync::Arc<std::sync::Mutex<$type>>,
$($attr_name: $attr_type),*
Expand All @@ -62,7 +62,7 @@ macro_rules! lang_wrapper {
};

($name:ident, $type:ty $(, $attr_name:ident : $attr_type:ty = $default:expr)*) => {
#[pyo3::prelude::pyclass]
#[pyo3::prelude::pyclass(module = "baml_py.baml_py")]
pub struct $name {
pub(crate) inner: $type,
$($attr_name: $attr_type),*
Expand All @@ -79,7 +79,7 @@ macro_rules! lang_wrapper {
};

($name:ident, $type:ty, no_from $(, $attr_name:ident : $attr_type:ty)*) => {
#[pyo3::prelude::pyclass]
#[pyo3::prelude::pyclass(module = "baml_py.baml_py")]
pub struct $name {
pub(crate) inner: $type,
$($attr_name: $attr_type),*
Expand Down
5 changes: 4 additions & 1 deletion engine/language_client_python/src/types/media_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ impl TryInto<UserFacingBamlMedia> for &BamlMedia {
/// can't implement this in internal_monkeypatch without adding a hard dependency
/// on pydantic. And we don't want to do _that_, because that will make it harder
/// to implement output_type python/vanilla in the future.
///
/// See docs:
/// https://docs.pydantic.dev/latest/concepts/types/#customizing-validation-with-__get_pydantic_core_schema__
pub fn __get_pydantic_core_schema__(
_cls: Bound<'_, PyType>,
_source_type: Bound<'_, PyAny>,
Expand Down Expand Up @@ -129,7 +132,7 @@ def get_schema():
ret = get_schema()
"#;
// py.run(code, None, Some(ret_dict));
let fun: Py<PyAny> = PyModule::from_code_bound(py, code, "", "")?
let fun: Py<PyAny> = PyModule::from_code_bound(py, code, file!(), crate::MODULE_NAME)?
.getattr("ret")?
.into();
Ok(fun.to_object(py))
Expand Down
11 changes: 11 additions & 0 deletions integ-tests/python/run_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#!/bin/bash

# Run tests for CI

set -euxo pipefail

env -u CONDA_PREFIX poetry run maturin develop --manifest-path ../../engine/language_client_python/Cargo.toml
poetry run baml-cli generate --from ../baml_src

# test_functions.py is excluded because it requires credentials
poetry run pytest "$@" --ignore=tests/test_functions.py
44 changes: 44 additions & 0 deletions integ-tests/python/tests/test_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Test the compatibility of baml_py with the Python ecosystem."""

import baml_py
import inspect
import pickle
import pydantic
import pytest


def test_inspect():
"""Assert that baml_py is compatible with the inspect module.
This is a regression test for a bug where `inspect.stack()` would implode if the
pyo3 code called `PyModule::from_code` without specifying the `file_name` arg (i.e.
without specifying the source file metadata for the inline Python snippet).
"""

class LoremIpsum(pydantic.BaseModel): # pyright: ignore[reportUnusedClass]
"""Defining this Pydantic model alone is sufficient to trigger the bug."""

my_image: baml_py.Image
my_audio: baml_py.Audio

try:
inspect.stack()
except Exception as e:
pytest.fail(f"inspect.stack() raised an unexpected exception: {e}")


def test_pickle():
i = baml_py.Image.from_url("https://example.com/image.png")
p = pickle.dumps(i)
assert i == pickle.loads(pickle.dumps(i))
assert p == pickle.dumps(pickle.loads(p))

i2 = baml_py.Image.from_url("https://example.com/image.jpg")
p2 = pickle.dumps(i2)
assert i2 == pickle.loads(pickle.dumps(i2))
assert p2 == pickle.dumps(pickle.loads(p2))

i3 = baml_py.Image.from_base64("image/png", "iVBORw0KGgoAAAANSUhEUgAAAAUA")
p3 = pickle.dumps(i3)
assert i3 == pickle.loads(pickle.dumps(i3))
assert p3 == pickle.dumps(pickle.loads(p3))

0 comments on commit 6741999

Please sign in to comment.