Skip to content

Commit

Permalink
chore: upgrade pyo3 (#1215)
Browse files Browse the repository at this point in the history
Since I have pyo3 paged into memory right now, figured that I would just
take the plunge, because I keep getting confused every single time I
look up pyo3 docs and see "bound" missing.
<!-- ELLIPSIS_HIDDEN -->


----

> [!IMPORTANT]
> Upgrade `pyo3` to `0.23.3` and adjust codebase for compatibility,
including dependency updates, code changes, and new integration tests.
> 
>   - **Dependencies**:
>     - Upgrade `pyo3` to `0.23.3` in `Cargo.toml` and `Cargo.lock`.
>     - Replace `pyo3-asyncio` with `pyo3-async-runtimes`.
>     - Update `pythonize` to `0.23`.
>   - **Code Adjustments**:
> - Replace `import_bound` with `import` in `errors.rs`, `lib.rs`,
`parse_py_type.rs`.
> - Use `into_py_any` instead of `to_object` for PyO3 conversions in
`parse_py_type.rs`, `runtime.rs`.
> - Update `lang_wrapper!` macro to include `module` attribute in
`lang_wrapper.rs`.
>   - **Error Handling**:
> - Modify exception handling in `errors.rs` to use updated PyO3
methods.
>   - **Integration Tests**:
>     - Add `run_tests.sh` for running Python integration tests.
> - Add `test_python.py` to test `inspect` and `pickle` compatibility
with `baml_py`.
>   - **Miscellaneous**:
> - Add `integ-tests` job to GitHub Actions workflow in `primary.yml`.
>     - Add `poetry` version `1.8.4` to `.mise.toml`.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup>
for afbe94a. It will automatically
update as commits are pushed.</sup>


<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
sxlijin authored Dec 5, 2024
1 parent 6741999 commit 44b57f7
Show file tree
Hide file tree
Showing 16 changed files with 167 additions and 129 deletions.
67 changes: 45 additions & 22 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 5 additions & 7 deletions engine/language_client_python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@ env_logger.workspace = true
futures.workspace = true
indexmap.workspace = true
log.workspace = true
pyo3 = { version = "0.21.2", default-features = false, features = [
# Consult https://pyo3.rs/main/migration for migration instructions
pyo3 = { version = "0.23.3", default-features = false, features = [
"abi3-py38",
"extension-module",
"generate-import-lib",
"serde",
] }
# pyo3-asyncio head is still on 0.20.0; someone's done the work of updating it to 0.21 and Bound<>,
# but that work hasn't been merged yet. it builds though, and looks good to me!
# https://github.com/awestlake87/pyo3-asyncio/pull/121
pyo3-asyncio = { git = "https://github.com/BoundaryML/pyo3-asyncio.git", branch = "migration-pyo3-0.21", features = [
pyo3-async-runtimes = { version = "0.23", features = [
"attributes",
"tokio-runtime",
] }
#pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] }
pythonize = "0.21.1"
pythonize = "0.23"
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
Expand Down
39 changes: 22 additions & 17 deletions engine/language_client_python/src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use baml_runtime::{
errors::ExposedError, internal::llm_client::LLMResponse, scope_diagnostics::ScopeStack,
};
use pyo3::types::{PyAnyMethods, PyModule};
use pyo3::types::{PyAnyMethods, PyModule, PyModuleMethods};
use pyo3::{create_exception, pymodule, Bound, PyErr, PyResult, Python};

create_exception!(baml_py, BamlError, pyo3::exceptions::PyException);
Expand All @@ -17,19 +17,26 @@ create_exception!(baml_py, BamlClientHttpError, BamlClientError);
#[allow(non_snake_case)]
fn raise_baml_validation_error(prompt: String, message: String, raw_output: String) -> PyErr {
Python::with_gil(|py| {
let internal_monkeypatch = py.import_bound("baml_py.internal_monkeypatch").unwrap();
let internal_monkeypatch = py.import("baml_py.internal_monkeypatch").unwrap();
let exception = internal_monkeypatch.getattr("BamlValidationError").unwrap();
let args = (prompt, message, raw_output);
let inst = exception.call1(args).unwrap();
PyErr::from_value_bound(inst)
PyErr::from_value(inst)
})
}

#[allow(non_snake_case)]
fn raise_baml_client_finish_reason_error(prompt: String, raw_output: String, message: String, finish_reason: Option<String>) -> PyErr {
fn raise_baml_client_finish_reason_error(
prompt: String,
raw_output: String,
message: String,
finish_reason: Option<String>,
) -> PyErr {
Python::with_gil(|py| {
let internal_monkeypatch = py.import("baml_py.internal_monkeypatch").unwrap();
let exception = internal_monkeypatch.getattr("BamlClientFinishReasonError").unwrap();
let exception = internal_monkeypatch
.getattr("BamlClientFinishReasonError")
.unwrap();
let args = (prompt, message, raw_output, finish_reason);
let inst = exception.call1(args).unwrap();
PyErr::from_value(inst)
Expand All @@ -40,23 +47,18 @@ fn raise_baml_client_finish_reason_error(prompt: String, raw_output: String, mes
/// IIRC the name of this function is the name of the module that pyo3 generates (errors.py)
#[pymodule]
pub fn errors(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
parent_module.add(
"BamlError",
parent_module.py().get_type_bound::<BamlError>(),
)?;
parent_module.add("BamlError", parent_module.py().get_type::<BamlError>())?;
parent_module.add(
"BamlInvalidArgumentError",
parent_module
.py()
.get_type_bound::<BamlInvalidArgumentError>(),
parent_module.py().get_type::<BamlInvalidArgumentError>(),
)?;
parent_module.add(
"BamlClientError",
parent_module.py().get_type_bound::<BamlClientError>(),
parent_module.py().get_type::<BamlClientError>(),
)?;
parent_module.add(
"BamlClientHttpError",
parent_module.py().get_type_bound::<BamlClientHttpError>(),
parent_module.py().get_type::<BamlClientHttpError>(),
)?;

Ok(())
Expand All @@ -80,9 +82,12 @@ impl BamlError {
raw_output,
message,
finish_reason,
} => {
raise_baml_client_finish_reason_error(prompt.clone(), raw_output.clone(), message.clone(), finish_reason.clone())
}
} => raise_baml_client_finish_reason_error(
prompt.clone(),
raw_output.clone(),
message.clone(),
finish_reason.clone(),
),
}
} else if let Some(er) = err.downcast_ref::<ScopeStack>() {
PyErr::new::<BamlInvalidArgumentError, _>(format!("Invalid argument: {}", er))
Expand Down
3 changes: 2 additions & 1 deletion engine/language_client_python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ mod runtime;
mod types;

use pyo3::prelude::{pyfunction, pymodule, PyAnyMethods, PyModule, PyResult};
use pyo3::types::PyModuleMethods;
use pyo3::{wrap_pyfunction, Bound, Python};
use tracing_subscriber::{self, EnvFilter};

#[pyfunction]
fn invoke_runtime_cli(py: Python) -> PyResult<()> {
baml_cli::run_cli(
py.import_bound("sys")?
py.import("sys")?
.getattr("argv")?
.extract::<Vec<String>>()?,
baml_runtime::RuntimeCliDefaults {
Expand Down
32 changes: 15 additions & 17 deletions engine/language_client_python/src/parse_py_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::{
exceptions::{PyRuntimeError, PyTypeError},
prelude::{PyAnyMethods, PyTypeMethods},
types::{PyBool, PyBoolMethods, PyDict, PyDictMethods, PyList},
PyErr, PyObject, PyResult, Python, ToPyObject,
IntoPyObjectExt, PyErr, PyObject, PyResult, Python,
};

use crate::types::{BamlAudioPy, BamlImagePy};
Expand Down Expand Up @@ -209,10 +209,8 @@ pub fn parse_py_type(
serialize_unknown_types_as_str: bool,
) -> PyResult<Option<BamlValue>> {
Python::with_gil(|py| {
let enum_type = py.import_bound("enum").and_then(|m| m.getattr("Enum"))?;
let base_model = py
.import_bound("pydantic")
.and_then(|m| m.getattr("BaseModel"))?;
let enum_type = py.import("enum").and_then(|m| m.getattr("Enum"))?;
let base_model = py.import("pydantic").and_then(|m| m.getattr("BaseModel"))?;

let mut get_type = |py: Python<'_>,
any: PyObject,
Expand All @@ -227,10 +225,10 @@ pub fn parse_py_type(
let name = t
.name()
.map(|n| {
if let Some(x) = n.rfind("baml_client.types.") {
n[x + "baml_client.types.".len()..].to_string()
} else {
n.to_string()
let n = n.to_string();
match n.strip_prefix("baml_client.types.") {
Some(s) => s.to_string(),
None => n,
}
})
.unwrap_or("<UnnamedEnum>".to_string());
Expand All @@ -241,10 +239,10 @@ pub fn parse_py_type(
let name = t
.name()
.map(|n| {
if let Some(x) = n.rfind("baml_client.types.") {
n[x + "baml_client.types.".len()..].to_string()
} else {
n.to_string()
let n = n.to_string();
match n.strip_prefix("baml_client.types.") {
Some(s) => s.to_string(),
None => n,
}
})
.unwrap_or("<UnnamedBaseModel>".to_string());
Expand All @@ -256,7 +254,7 @@ pub fn parse_py_type(
{
for (key, _) in model_fields {
if let Ok(value) = any.getattr(py, key.as_str()) {
fields.insert(key, value.to_object(py));
fields.insert(key, value.into_py_any(py)?);
}
}
}
Expand All @@ -266,7 +264,7 @@ pub fn parse_py_type(
if let Ok(extra_dict) = extra.downcast_bound::<PyDict>(py) {
for (key, value) in extra_dict.iter() {
if let (Ok(key), value) = (key.extract::<String>(), value) {
fields.insert(key, value.to_object(py));
fields.insert(key, value.into_py_any(py)?);
}
}
}
Expand All @@ -276,7 +274,7 @@ pub fn parse_py_type(
// log::info!("Fields of {}:", name);
// for (key, value) in &fields {
// let repr = py
// .import_bound("builtins")?
// .import("builtins")?
// .getattr("repr")?
// .call1((value,))?;
// let repr_str = repr.extract::<String>()?;
Expand All @@ -288,7 +286,7 @@ pub fn parse_py_type(
let mut items = vec![];
let len = list.len()?;
for idx in 0..len {
items.push(list.get_item(idx)?.to_object(py));
items.push(list.get_item(idx)?.into_py_any(py)?);
}
Ok(MappedPyType::List(items))
} else if let Ok(kv) = any.extract::<HashMap<String, PyObject>>(py) {
Expand Down
Loading

0 comments on commit 44b57f7

Please sign in to comment.