From a8008de75ee3a45d7871b53d114204a3fe1de617 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Mon, 23 Sep 2024 10:35:55 -0600 Subject: [PATCH] Rebase to master --- setup.py | 3 +- src/main/c/jpy_jmethod.c | 4 + src/main/c/jpy_jobj.c | 2 +- src/main/c/jpy_jtype.c | 66 ++++++++++ src/main/c/jpy_module.c | 3 + .../MultiThreadedEvalTestFixture.java | 54 ++++++++ src/test/python/jpy_eval_exec_test.py | 1 + src/test/python/jpy_mt_eval_exec_test.py | 121 ++++++++++++++++++ 8 files changed, 252 insertions(+), 2 deletions(-) create mode 100644 src/test/java/org/jpy/fixtures/MultiThreadedEvalTestFixture.java create mode 100644 src/test/python/jpy_mt_eval_exec_test.py diff --git a/setup.py b/setup.py index e370f8e7..f9891973 100644 --- a/setup.py +++ b/setup.py @@ -97,6 +97,7 @@ os.path.join(src_test_py_dir, 'jpy_java_embeddable_test.py'), os.path.join(src_test_py_dir, 'jpy_obj_test.py'), os.path.join(src_test_py_dir, 'jpy_eval_exec_test.py'), + os.path.join(src_test_py_dir, 'jpy_mt_eval_exec_test.py'), ] # e.g. jdk_home_dir = '/home/marta/jdk1.7.0_15' @@ -279,7 +280,7 @@ def test_python_with_java_classes(self): def test_java(self): assert test_maven() - suite.addTest(test_python_with_java_runtime) + # suite.addTest(test_python_with_java_runtime) suite.addTest(test_python_with_java_classes) # comment out because the asynchronous nature of the PyObject GC in Java makes stopPython/startPython flakey. # suite.addTest(test_java) diff --git a/src/main/c/jpy_jmethod.c b/src/main/c/jpy_jmethod.c index e4469e95..ab49bd31 100644 --- a/src/main/c/jpy_jmethod.c +++ b/src/main/c/jpy_jmethod.c @@ -799,6 +799,8 @@ JPy_JMethod* JOverloadedMethod_FindMethod0(JNIEnv* jenv, JPy_JOverloadedMethod* overloadedMethod->declaringClass->javaName, JPy_AS_UTF8(overloadedMethod->name), overloadCount, argCount); for (i = 0; i < overloadCount; i++) { + // borrowed reference but no need to replace it with PyList_GetItemRef(), because the list won't be + // changed concurrently currMethod = (JPy_JMethod*) PyList_GetItem(overloadedMethod->methodList, i); if (currMethod->isVarArgs && matchValueMax > 0 && !bestMethod->isVarArgs) { @@ -950,6 +952,8 @@ int JOverloadedMethod_AddMethod(JPy_JOverloadedMethod* overloadedMethod, JPy_JMe // we need to insert this before the first varargs method Py_ssize_t size = PyList_Size(overloadedMethod->methodList); for (ii = 0; ii < size; ii++) { + // borrowed reference but no need to replace it with PyList_GetItemRef(), because the list won't be + // changed concurrently PyObject *check = PyList_GetItem(overloadedMethod->methodList, ii); if (((JPy_JMethod *) check)->isVarArgs) { // this is the first varargs method, so we should go before it diff --git a/src/main/c/jpy_jobj.c b/src/main/c/jpy_jobj.c index 84880a27..06687d5b 100644 --- a/src/main/c/jpy_jobj.c +++ b/src/main/c/jpy_jobj.c @@ -72,7 +72,7 @@ PyObject* JObj_FromType(JNIEnv* jenv, JPy_JType* type, jobject objectRef) } -// we check the type translations dictionary for a callable for this java type name, + // we check the type translations dictionary for a callable for this java type name, // and apply the returned callable to the wrapped object callable = PyDict_GetItemString(JPy_Type_Translations, type->javaName); if (callable != NULL) { diff --git a/src/main/c/jpy_jtype.c b/src/main/c/jpy_jtype.c index 7798009d..f1877bca 100644 --- a/src/main/c/jpy_jtype.c +++ b/src/main/c/jpy_jtype.c @@ -27,6 +27,54 @@ #include "jpy_conv.h" #include "jpy_compat.h" +#ifdef Py_GIL_DISABLED +typedef struct { + PyMutex lock; + PyThreadState* owner; + int recursion_level; +} ReentrantLock; + +static void acquire_lock(ReentrantLock* self) { + PyThreadState* current_thread = PyThreadState_Get(); + + if (self->owner == current_thread) { + self->recursion_level++; + return; + } + + PyMutex_Lock(&(self->lock)); + + self->owner = current_thread; + self->recursion_level = 1; +} + +static void release_lock(ReentrantLock* self) { + if (self->owner != PyThreadState_Get()) { + PyErr_SetString(PyExc_RuntimeError, "Lock not owned by current thread"); + return; + } + + self->recursion_level--; + if (self->recursion_level == 0) { + self->owner = NULL; + PyMutex_Unlock(&(self->lock)); + } +} + +static ReentrantLock get_type_rlock = {{0}, NULL, 0}; +static ReentrantLock resolve_type_rlock = {{0}, NULL, 0}; + +#define ACQUIRE_GET_TYPE_LOCK() acquire_lock(&get_type_rlock) +#define RELEASE_GET_TYPE_LOCK() release_lock(&get_type_rlock) +#define ACQUIRE_RESOLVE_TYPE_LOCK() acquire_lock(&resolve_type_rlock) +#define RELEASE_RESOLVE_TYPE_LOCK() release_lock(&resolve_type_rlock) + +#else +#define ACQUIRE_GET_TYPE_LOCK() +#define RELEASE_GET_TYPE_LOCK() +#define ACQUIRE_RESOLVE_TYPE_LOCK() +#define RELEASE_RESOLVE_TYPE_LOCK() +#endif JPy_JType* JType_New(JNIEnv* jenv, jclass classRef, jboolean resolve); int JType_ResolveType(JNIEnv* jenv, JPy_JType* type); @@ -52,6 +100,8 @@ static int JType_MatchVarArgPyArgAsFPType(const JPy_ParamDescriptor *paramDescri static int JType_MatchVarArgPyArgIntType(const JPy_ParamDescriptor *paramDescriptor, PyObject *pyArg, int idx, struct JPy_JType *expectedComponentType); + + JPy_JType* JType_GetTypeForObject(JNIEnv* jenv, jobject objectRef, jboolean resolve) { JPy_JType* type; @@ -151,6 +201,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) return NULL; } + ACQUIRE_GET_TYPE_LOCK(); typeValue = PyDict_GetItem(JPy_Types, typeKey); if (typeValue == NULL) { @@ -160,6 +211,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) type = JType_New(jenv, classRef, resolve); if (type == NULL) { JPy_DECREF(typeKey); + RELEASE_GET_TYPE_LOCK(); return NULL; } @@ -184,6 +236,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) PyDict_DelItem(JPy_Types, typeKey); JPy_DECREF(typeKey); JPy_DECREF(type); + RELEASE_GET_TYPE_LOCK(); return NULL; } @@ -195,6 +248,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) PyDict_DelItem(JPy_Types, typeKey); JPy_DECREF(typeKey); JPy_DECREF(type); + RELEASE_GET_TYPE_LOCK(); return NULL; } @@ -206,6 +260,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) PyDict_DelItem(JPy_Types, typeKey); JPy_DECREF(typeKey); JPy_DECREF(type); + RELEASE_GET_TYPE_LOCK(); return NULL; } @@ -231,6 +286,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) "jpy internal error: attributes in 'jpy.%s' must be of type '%s', but found a '%s'", JPy_MODULE_ATTR_NAME_TYPES, JType_Type.tp_name, Py_TYPE(typeValue)->tp_name); JPy_DECREF(typeKey); + RELEASE_GET_TYPE_LOCK(); return NULL; } @@ -240,6 +296,7 @@ JPy_JType* JType_GetType(JNIEnv* jenv, jclass classRef, jboolean resolve) } JPy_DIAG_PRINT(JPy_DIAG_F_TYPE, "JType_GetType: javaName=\"%s\", found=%d, resolve=%d, resolved=%d, type=%p\n", type->javaName, found, resolve, type->isResolved, type); + RELEASE_GET_TYPE_LOCK(); if (!type->isResolved && resolve) { if (JType_ResolveType(jenv, type) < 0) { @@ -968,7 +1025,10 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type) { PyTypeObject* typeObj; + ACQUIRE_RESOLVE_TYPE_LOCK(); + if (type->isResolved || type->isResolving) { + RELEASE_RESOLVE_TYPE_LOCK(); return 0; } @@ -980,6 +1040,7 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type) if (!baseType->isResolved) { if (JType_ResolveType(jenv, baseType) < 0) { type->isResolving = JNI_FALSE; + RELEASE_RESOLVE_TYPE_LOCK(); return -1; } } @@ -988,24 +1049,29 @@ int JType_ResolveType(JNIEnv* jenv, JPy_JType* type) //printf("JType_ResolveType 1\n"); if (JType_ProcessClassConstructors(jenv, type) < 0) { type->isResolving = JNI_FALSE; + RELEASE_RESOLVE_TYPE_LOCK(); return -1; } //printf("JType_ResolveType 2\n"); if (JType_ProcessClassMethods(jenv, type) < 0) { type->isResolving = JNI_FALSE; + RELEASE_RESOLVE_TYPE_LOCK(); return -1; } //printf("JType_ResolveType 3\n"); if (JType_ProcessClassFields(jenv, type) < 0) { type->isResolving = JNI_FALSE; + RELEASE_RESOLVE_TYPE_LOCK(); return -1; } //printf("JType_ResolveType 4\n"); type->isResolving = JNI_FALSE; type->isResolved = JNI_TRUE; + + RELEASE_RESOLVE_TYPE_LOCK(); return 0; } diff --git a/src/main/c/jpy_module.c b/src/main/c/jpy_module.c index 3c4afe8a..6ffd90d1 100644 --- a/src/main/c/jpy_module.c +++ b/src/main/c/jpy_module.c @@ -323,6 +323,9 @@ PyMODINIT_FUNC JPY_MODULE_INIT_FUNC(void) if (JPy_Module == NULL) { JPY_RETURN(NULL); } +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(JPy_Module, Py_MOD_GIL_NOT_USED); +#endif #elif defined(JPY_COMPAT_27) JPy_Module = Py_InitModule3(JPY_MODULE_NAME, JPy_Functions, JPY_MODULE_DOC); if (JPy_Module == NULL) { diff --git a/src/test/java/org/jpy/fixtures/MultiThreadedEvalTestFixture.java b/src/test/java/org/jpy/fixtures/MultiThreadedEvalTestFixture.java new file mode 100644 index 00000000..f7ca69f3 --- /dev/null +++ b/src/test/java/org/jpy/fixtures/MultiThreadedEvalTestFixture.java @@ -0,0 +1,54 @@ +package org.jpy.fixtures; + +import org.jpy.PyInputMode; +import org.jpy.PyLib; +import org.jpy.PyObject; + +import java.util.List; + +public class MultiThreadedEvalTestFixture { + + public static void expression(String expression, int numThreads) { + PyObject globals = PyLib.getCurrentGlobals(); + PyObject locals = PyLib.getCurrentLocals(); + + List threads = new java.util.ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + threads.add(new Thread(() -> { + PyObject.executeCode(expression, PyInputMode.EXPRESSION, globals, locals); + })); + } + for (Thread thread : threads) { + thread.start(); + } + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + public static void script(String expression, int numThreads) { + List threads = new java.util.ArrayList<>(); + PyObject globals = PyLib.getCurrentGlobals(); + PyObject locals = PyLib.getCurrentLocals(); + for (int i = 0; i < numThreads; i++) { + threads.add(new Thread(() -> { + PyObject.executeCode(expression, PyInputMode.SCRIPT, globals, locals); + })); + } + for (Thread thread : threads) { + thread.start(); + } + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + +} diff --git a/src/test/python/jpy_eval_exec_test.py b/src/test/python/jpy_eval_exec_test.py index e7bc8798..46f4530b 100644 --- a/src/test/python/jpy_eval_exec_test.py +++ b/src/test/python/jpy_eval_exec_test.py @@ -5,6 +5,7 @@ jpyutil.init_jvm(jvm_maxmem='512M', jvm_classpath=['target/classes', 'target/test-classes']) import jpy + class TestEvalExec(unittest.TestCase): def setUp(self): self.fixture = jpy.get_type("org.jpy.fixtures.EvalTestFixture") diff --git a/src/test/python/jpy_mt_eval_exec_test.py b/src/test/python/jpy_mt_eval_exec_test.py new file mode 100644 index 00000000..dba7acc0 --- /dev/null +++ b/src/test/python/jpy_mt_eval_exec_test.py @@ -0,0 +1,121 @@ +import math +import unittest + +import jpyutil + +jpyutil.init_jvm(jvm_maxmem='512M', jvm_classpath=['target/classes', 'target/test-classes']) +import jpy +# jpy.diag.flags = jpy.diag.F_TYPE + +NUM_THREADS = 20 + + +# A CPU-bound task: computing a large number of prime numbers +def is_prime(n: int) -> bool: + if n <= 1: + return False + for i in range(2, int(math.sqrt(n)) + 1): + if n % i == 0: + return False + return True + + +def count_primes(start: int, end: int) -> int: + count = 0 + for i in range(start, end): + if is_prime(i): + count += 1 + return count + + +def use_circular_java_classes(): + j_child1_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild1") + j_child2_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild2") + j_child2 = j_child2_class() + j_child1 = j_child1_class.of(8) + result = j_child1.parentMethod() + assert result == 88 + assert 888 == j_child1.grandParentMethod() + j_child1.refChild2(j_child2) + assert 8 == j_child1.get_x() + assert 10 == j_child1.y + assert 100 == j_child1.z + + +class MultiThreadedTestEvalExec(unittest.TestCase): + def setUp(self): + self.fixture = jpy.get_type("org.jpy.fixtures.MultiThreadedEvalTestFixture") + self.assertIsNotNone(self.fixture) + + def atest_inc_baz(self): + baz = 15 + self.fixture.script("baz = baz + 1; self.assertGreater(baz, 15)", NUM_THREADS) + # note: this *is* correct wrt python semantics w/ exec(code, globals(), locals()) + # https://bugs.python.org/issue4831 (Note: it's *not* a bug, is working as intended) + self.assertEqual(baz, 15) + + def atest_exec_import(self): + import sys + self.assertTrue("json" not in sys.modules) + self.fixture.script("import json", NUM_THREADS) + self.assertTrue("json" in sys.modules) + + def atest_exec_function_call(self): + self.fixture.expression("use_circular_java_classes()", NUM_THREADS) + + def test_count_primes(self): + self.fixture.expression("count_primes(1, 10000)", NUM_THREADS) + + def atest_java_threading_jpy_get_type(self): + + py_script = """ +j_child1_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild1") +j_child2_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild2") +j_child2 = j_child2_class() +j_child1 = j_child1_class.of(8) +result = j_child1.parentMethod() +assert result == 88 +assert 888 == j_child1.grandParentMethod() +j_child1.refChild2(j_child2) +assert 8 == j_child1.get_x() +assert 10 == j_child1.y +assert 100 == j_child1.z + """ + self.fixture.script(py_script, NUM_THREADS) + + def atest_py_threading_jpy_get_type(self): + import threading + + test_self = self + + class MyThread(threading.Thread): + def __init__(self): + threading.Thread.__init__(self) + + def run(self): + barrier.wait() + j_child1_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild1") + j_child2_class = jpy.get_type("org.jpy.fixtures.CyclicReferenceChild2") + j_child2 = j_child2_class() + j_child1 = j_child1_class.of(8) + test_self.assertEqual(88, j_child1.parentMethod()) + test_self.assertEqual(888, j_child1.grandParentMethod()) + test_self.assertIsNone(j_child1.refChild2(j_child2)) + test_self.assertEqual(8, j_child1.get_x()) + test_self.assertEqual(10, j_child1.y) + test_self.assertEqual(100, j_child1.z) + + barrier = threading.Barrier(NUM_THREADS) + threads = [] + for i in range(NUM_THREADS): + t = MyThread() + t.start() + threads.append(t) + + for t in threads: + t.join() + + +if __name__ == '__main__': + print('\nRunning ' + __file__) + unittest.main()