From 3b5f6af2eaa0519643ccc2a4c1395307bfd3ad7e Mon Sep 17 00:00:00 2001 From: Mark Harris <783069+harrism@users.noreply.github.com> Date: Wed, 20 Nov 2024 12:22:49 +1100 Subject: [PATCH] Query total memory in failure_callback_resource_adaptor tests (#1734) Fixes #1733 by querying total device memory and using twice as much in tests that are expected to fail allocation. Authors: - Mark Harris (https://github.com/harrism) Approvers: - Bradley Dice (https://github.com/bdice) URL: https://github.com/rapidsai/rmm/pull/1734 --- python/rmm/rmm/tests/test_rmm.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/python/rmm/rmm/tests/test_rmm.py b/python/rmm/rmm/tests/test_rmm.py index d7d692287..182434dc5 100644 --- a/python/rmm/rmm/tests/test_rmm.py +++ b/python/rmm/rmm/tests/test_rmm.py @@ -795,10 +795,28 @@ def callback(nbytes: int) -> bool: rmm.mr.set_current_device_resource(mr) with pytest.raises(MemoryError): - rmm.DeviceBuffer(size=int(1e11)) + from rmm.mr import available_device_memory + + total_memory = available_device_memory()[1] + rmm.DeviceBuffer(size=total_memory * 2) assert retried[0] +def test_failure_callback_resource_adaptor_error(): + def callback(nbytes: int) -> bool: + raise RuntimeError("MyError") + + cuda_mr = rmm.mr.CudaMemoryResource() + mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback) + rmm.mr.set_current_device_resource(mr) + + with pytest.raises(RuntimeError, match="MyError"): + from rmm.mr import available_device_memory + + total_memory = available_device_memory()[1] + rmm.DeviceBuffer(size=total_memory * 2) + + @pytest.mark.parametrize("managed", [True, False]) def test_prefetch_resource_adaptor(managed): if managed: @@ -823,18 +841,6 @@ def test_prefetch_resource_adaptor(managed): assert_prefetched(db, device) -def test_failure_callback_resource_adaptor_error(): - def callback(nbytes: int) -> bool: - raise RuntimeError("MyError") - - cuda_mr = rmm.mr.CudaMemoryResource() - mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback) - rmm.mr.set_current_device_resource(mr) - - with pytest.raises(RuntimeError, match="MyError"): - rmm.DeviceBuffer(size=int(1e11)) - - def test_dev_buf_circle_ref_dealloc(): # This test creates a reference cycle containing a `DeviceBuffer` # and ensures that the garbage collector does not clear it, i.e.,