diff --git a/python/rmm/rmm/tests/test_rmm.py b/python/rmm/rmm/tests/test_rmm.py index d7d692287..182434dc5 100644 --- a/python/rmm/rmm/tests/test_rmm.py +++ b/python/rmm/rmm/tests/test_rmm.py @@ -795,10 +795,28 @@ def callback(nbytes: int) -> bool: rmm.mr.set_current_device_resource(mr) with pytest.raises(MemoryError): - rmm.DeviceBuffer(size=int(1e11)) + from rmm.mr import available_device_memory + + total_memory = available_device_memory()[1] + rmm.DeviceBuffer(size=total_memory * 2) assert retried[0] +def test_failure_callback_resource_adaptor_error(): + def callback(nbytes: int) -> bool: + raise RuntimeError("MyError") + + cuda_mr = rmm.mr.CudaMemoryResource() + mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback) + rmm.mr.set_current_device_resource(mr) + + with pytest.raises(RuntimeError, match="MyError"): + from rmm.mr import available_device_memory + + total_memory = available_device_memory()[1] + rmm.DeviceBuffer(size=total_memory * 2) + + @pytest.mark.parametrize("managed", [True, False]) def test_prefetch_resource_adaptor(managed): if managed: @@ -823,18 +841,6 @@ def test_prefetch_resource_adaptor(managed): assert_prefetched(db, device) -def test_failure_callback_resource_adaptor_error(): - def callback(nbytes: int) -> bool: - raise RuntimeError("MyError") - - cuda_mr = rmm.mr.CudaMemoryResource() - mr = rmm.mr.FailureCallbackResourceAdaptor(cuda_mr, callback) - rmm.mr.set_current_device_resource(mr) - - with pytest.raises(RuntimeError, match="MyError"): - rmm.DeviceBuffer(size=int(1e11)) - - def test_dev_buf_circle_ref_dealloc(): # This test creates a reference cycle containing a `DeviceBuffer` # and ensures that the garbage collector does not clear it, i.e.,