Skip to content

Commit

Permalink
Let errors propagate
Browse files Browse the repository at this point in the history
  • Loading branch information
danhje committed Mar 30, 2024
1 parent e5df833 commit 0f916fc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 40 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ name = "quick_xmltodict"
crate-type = ["cdylib"]

[dependencies]
pyo3 = "*"
anyhow = "*"
pyo3 = { version = "*", features = ["anyhow"] }
quick-xml = "*"
79 changes: 40 additions & 39 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
use anyhow::Result;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyNone};
use quick_xml::events::Event;
use quick_xml::name::QName;
use quick_xml::reader::Reader;

trait QNameExt {
fn qn(&self) -> String;
fn qn(&self) -> Result<String>;
}

impl QNameExt for QName<'_> {
/// Returns the qualified name of the element (prefix:local_name).
fn qn(&self) -> String {
let mut name = std::str::from_utf8(self.local_name().as_ref()).unwrap().to_string();
fn qn(&self) -> Result<String> {
let mut name = std::str::from_utf8(self.local_name().as_ref())?.to_string();
if let Some(prefix) = self.prefix() {
name = format!("{}:{}", std::str::from_utf8(prefix.as_ref()).unwrap(), name);
name = format!("{}:{}", std::str::from_utf8(prefix.as_ref())?, name);
}
name
Ok(name)
}
}

Expand All @@ -36,99 +37,99 @@ impl ToPyObject for Value<'_> {
}
}

fn _update_dict<'a>(py: Python<'a>, d: &'a PyDict, tag_name: &str, value: &'a PyObject) {
if d.contains(tag_name).unwrap() {
let existing_val = d.get_item(tag_name).unwrap();
let list: &PyList;
if existing_val.unwrap().is_instance_of::<PyList>() {
list = existing_val.unwrap().downcast::<PyList>().unwrap();
} else {
list = PyList::new(py, existing_val);
fn _update_dict<'a>(py: Python<'a>, d: &'a PyDict, tag_name: &str, value: &'a PyObject) -> Result<()> {
match d.get_item(tag_name)? {
None => {
d.set_item(tag_name, value)?;
}
Some(existing_val) => {
let list: &PyList;
if existing_val.is_instance_of::<PyList>() {
list = existing_val.extract::<&PyList>()?;
} else {
list = PyList::new(py, vec![existing_val]);
}

list.append(value).unwrap();
d.set_item(tag_name, list).unwrap();
} else {
d.set_item(tag_name, value).unwrap();
list.append(value)?;
d.set_item(tag_name, list)?;
}
}
Ok(())
}

fn _parse<'a>(py: Python<'a>, xml: &'a str) -> &'a PyDict {
fn _parse<'a>(py: Python<'a>, xml: &'a str) -> Result<&'a PyDict> {
let mut reader = Reader::from_str(xml);
reader.trim_text(true);

let d = PyDict::new(py);
loop {
match reader.read_event() {
Err(e) => panic!("Error at position {}: {:?}", reader.buffer_position(), e),
Err(e) => return Err(e.into()),
Ok(Event::Eof) => break,
Ok(Event::Empty(e)) => {
let tag_name = e.name().qn();
let tag_name = e.name().qn()?;

let value: Value;
if e.attributes().count() == 0 {
value = Value::None;
} else {
let attrs = PyDict::new(py);
for attr in e.attributes() {
let attr = attr.unwrap();
attrs
.set_item(format!("@{}", attr.key.qn()), attr.unescape_value().unwrap())
.unwrap();
let attr = attr?;
attrs.set_item(format!("@{}", attr.key.qn()?), attr.unescape_value()?)?;
}
value = Value::Dict(attrs);
}
_update_dict(py, d, &tag_name, &value.to_object(py));
_update_dict(py, d, &tag_name, &value.to_object(py))?;
}
Ok(Event::Text(e)) => {
let text = e.unescape().unwrap();
d.set_item("#text".to_string(), text).unwrap();
let text = e.unescape()?;
d.set_item("#text".to_string(), text)?;
}
Ok(Event::Start(e)) => {
let tag_name = e.name().qn();
let tag_name = e.name().qn()?;

let mut value = Value::Dict(PyDict::new(py));
if e.attributes().count() > 0 {
for attr in e.attributes() {
let attr = attr.unwrap();
let attr = attr?;
match value {
Value::Dict(d) => {
d.set_item(format!("@{}", attr.key.qn()), attr.unescape_value().unwrap())
.unwrap();
d.set_item(format!("@{}", attr.key.qn()?), attr.unescape_value()?)?;
}
_ => unreachable!(),
}
}
}

let content = reader.read_text(e.name()).unwrap();
match value {
Value::Dict(d) => {
d.update(_parse(py, &content).as_mapping()).unwrap();
d.update(_parse(py, &(reader.read_text(e.name())?))?.as_mapping())?;
}
_ => unreachable!(),
}

match value {
Value::Dict(d) if d.len() == 1 && d.contains("#text").unwrap() => {
value = Value::Text(d.get_item("#text").unwrap().unwrap().extract::<String>().unwrap());
Value::Dict(d) if d.len() == 1 => {
if let Some(text) = d.get_item("#text")? {
value = Value::Text(text.extract::<String>()?);
}
}
_ => (),
}

_update_dict(py, d, &tag_name, &value.to_object(py));
_update_dict(py, d, &tag_name, &value.to_object(py))?;
}
_ => (),
}
}

d
Ok(d)
}

#[pyfunction]
fn parse(py: Python, xml: &str) -> PyResult<PyObject> {
let d = _parse(py, xml);
Ok(d.into())
Ok(_parse(py, xml)?.into())
}

#[pymodule]
Expand Down
25 changes: 25 additions & 0 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from xml.parsers.expat import ExpatError

import pytest
from quick_xmltodict import parse as rustparse
Expand Down Expand Up @@ -90,6 +91,15 @@ def test_namespace_prefixed(parse):
assert parse(xml) == target


def test_namespace_prefixed_attr(parse):
xml = """
<a xmlns:ns="http://example.com">
<ns:b ns:name="value">text</ns:b>
</a>"""
target = {"a": {"@xmlns:ns": "http://example.com", "ns:b": {"@ns:name": "value", "#text": "text"}}}
assert parse(xml) == target


@pytest.fixture
def data_dir():
return Path(__file__).parent / "data"
Expand All @@ -107,3 +117,18 @@ def forecast_target(data_dir):

def test_forecast(parse, forecast_xml, forecast_target):
assert parse(forecast_xml) == forecast_target


def test_error_missing_closing_tag(parse):
with pytest.raises((RuntimeError, ExpatError)):
parse("<a>")


def test_error_missing_opening_tag(parse):
with pytest.raises((RuntimeError, ExpatError)):
parse("</a>")


def test_error_malformed_tag(parse):
with pytest.raises((RuntimeError, ExpatError)):
parse("<a")

0 comments on commit 0f916fc

Please sign in to comment.