Skip to content

Commit

Permalink
Fix the reference count stealing problem in clif::callback.
Browse files Browse the repository at this point in the history
This is a follow up of cl/532595246. When the first parameter of the callback function is not `PyObject*`, and the second parameter of the callback function is `PyObject*`, the newly added function template specialization for `PyObject*` is not executed. See b/282776731#comment6 for details.

TGP only includes unrelated failures: https://fusion2.corp.google.com/presubmit/568959836/OCL:568959836:BASE:569532419:1696009662955:2e10d708/targets.

PiperOrigin-RevId: 576333623
  • Loading branch information
wangxf authored and Ralf W. Grosse-Kunstleve committed Nov 14, 2023
1 parent e56383f commit 0ced3b5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 5 additions & 0 deletions clif/python/stltypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ inline Py_ssize_t ArgIn(PyObject**, Py_ssize_t idx, const py::PostConv&) {
return idx;
}

// NOTE: This forward declaration is CRITICAL (see b/282776731#comment6).
template <typename... T>
Py_ssize_t ArgIn(PyObject** a, Py_ssize_t idx, const py::PostConv& pc,
PyObject* c1, T&&... c);

template <typename T1, typename... T>
Py_ssize_t ArgIn(PyObject** a, Py_ssize_t idx, const py::PostConv& pc, T1&& c1,
T&&... c) {
Expand Down
22 changes: 20 additions & 2 deletions clif/testing/python/pyobject_ptr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,40 @@ def cb(pvh):
self.assertEqual(cc(cb_guarded, PyValueHolder(3.0)).value, -123)
self.assertIn("ValueError: Unknown pvh.value: 3.0", sio.getvalue())

def test_call_callback_with_pyobject_ptr_int_args(self):
def test_call_callback_with_pyobject_ptr_int_args_temporary_arg(self):
def cb(pvh, num):
return tst.CppValueHolder(pvh.value * 10 + num)

cc = tst.call_callback_with_pyobject_ptr_int_args
for _ in range(1000):
self.assertEqual(cc(cb, PyValueHolder(30)).value, 340)

def test_call_callback_with_int_pyobject_ptr_args(self):
def test_call_callback_with_pyobject_ptr_int_args_named_arg(self):
def cb(pvh, num):
return tst.CppValueHolder(pvh.value * 10 + num)

cc = tst.call_callback_with_pyobject_ptr_int_args
value_holder = PyValueHolder(30)
for _ in range(1000):
self.assertEqual(cc(cb, value_holder).value, 340)

def test_call_callback_with_int_pyobject_ptr_args_temporary_arg(self):
def cb(num, pvh):
return tst.CppValueHolder(num * 20 + pvh.value)

cc = tst.call_callback_with_int_pyobject_ptr_args
for _ in range(1000):
self.assertEqual(cc(cb, PyValueHolder(60)).value, 1060)

def test_call_callback_with_int_pyobject_ptr_args_named_arg(self):
def cb(num, pvh):
return tst.CppValueHolder(num * 20 + pvh.value)

cc = tst.call_callback_with_int_pyobject_ptr_args
value_holder = PyValueHolder(60)
for _ in range(1000):
self.assertEqual(cc(cb, value_holder).value, 1060)


if __name__ == "__main__":
absltest.main()

0 comments on commit 0ced3b5

Please sign in to comment.