From ebaad559d6ea4a49fa87123adf50c9aa847ce27a Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Fri, 22 Jan 2021 16:06:18 -0600 Subject: [PATCH] Fix a bug involving thread local memory initialization --- tests/compile/test_compilelock.py | 33 +++++++++++++++++++++++++++++++ theano/compile/compilelock.py | 8 ++++++-- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_compilelock.py b/tests/compile/test_compilelock.py index a77f6674f5..ee2c7724ec 100644 --- a/tests/compile/test_compilelock.py +++ b/tests/compile/test_compilelock.py @@ -2,6 +2,8 @@ import os import sys import tempfile +import threading +import time import filelock import pytest @@ -78,6 +80,37 @@ def run_locking_test(ctx): assert get_subprocess_lock_state(ctx, dir_name) == "unlocked" +def test_locking_thread(): + + with tempfile.TemporaryDirectory() as dir_name: + + def test_fn_1(): + with lock_ctx(dir_name): + # Sleep "indefinitely" + time.sleep(100) + + def test_fn_2(arg): + try: + with lock_ctx(dir_name, timeout=0.1): + # If this can get the lock, then our file lock has failed + raise AssertionError() + except filelock.Timeout: + # It timed out, which means that the lock was still held by the + # first thread + arg.append(True) + + thread_1 = threading.Thread(target=test_fn_1) + res = [] + thread_2 = threading.Thread(target=test_fn_2, args=(res,)) + + thread_1.start() + thread_2.start() + + # The second thread should raise `filelock.Timeout` + thread_2.join() + assert True in res + + @pytest.mark.skipif(sys.platform != "linux", reason="Fork is only available on linux") def test_locking_multiprocess_fork(): ctx = multiprocessing.get_context("fork") diff --git a/theano/compile/compilelock.py b/theano/compile/compilelock.py index 227306bc06..719db9a597 100644 --- a/theano/compile/compilelock.py +++ b/theano/compile/compilelock.py @@ -18,8 +18,12 @@ ] -local_mem = threading.local() -local_mem._locks: typing.Dict[str, bool] = {} +class ThreadFileLocks(threading.local): + def __init__(self): + self._locks = {} + + +local_mem = ThreadFileLocks() def force_unlock(lock_dir: os.PathLike):