diff --git a/cpp/src/wholememory_ops/temp_memory_handle.hpp b/cpp/src/wholememory_ops/temp_memory_handle.hpp index 7f74677ba..408d3bfa1 100644 --- a/cpp/src/wholememory_ops/temp_memory_handle.hpp +++ b/cpp/src/wholememory_ops/temp_memory_handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ class temp_memory_handle { ~temp_memory_handle() { free_memory(); } void* device_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -40,7 +40,7 @@ class temp_memory_handle { } void* host_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -49,7 +49,7 @@ class temp_memory_handle { } void* pinned_malloc(size_t elt_count, wholememory_dtype_t data_type) { - free_memory(); + free_data(); wholememory_tensor_description_t tensor_description; get_tensor_description(&tensor_description, elt_count, data_type); ptr_ = temp_mem_fns_->malloc_fn( @@ -57,6 +57,13 @@ class temp_memory_handle { return ptr_; } [[nodiscard]] void* pointer() const { return ptr_; } + void free_data() + { + if (ptr_ != nullptr) { + temp_mem_fns_->free_fn(memory_context_, temp_mem_fns_->global_context); + ptr_ = nullptr; + } + } void free_memory() { if (ptr_ != nullptr) { diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py index aba4b6bea..d083a8abc 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# Copyright (c) 2019-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -75,6 +75,11 @@ def free(self): torch_cpp_ext_lib.destroy_output_context(self.get_handle()) self.handle = 0 + def free_data(self): + self.tensor = None + if torch_cpp_ext_loaded and self.get_handle() != 0: + torch_cpp_ext_lib.free_context_data(self.get_handle()) + def torch_create_memory_context_env_fn( global_context: TorchEmptyGlobalContext, @@ -121,7 +126,7 @@ def torch_malloc_env_fn( def torch_free_env_fn( memory_context: TorchMemoryContext, global_context: TorchEmptyGlobalContext ): - memory_context.free() + memory_context.free_data() class ExtContextWrapper(object): diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp index 15d2e5160..be0385d9c 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,4 +60,8 @@ void destroy_output_context(void* output_context) { destroy_torch_memory_context_func(output_context, nullptr); } +void free_context_data(void* output_context) { + torch_common_free_func(output_context, nullptr); +} + } // namespace wholegraph_torch diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp index f1dcbecdb..d805d24a5 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp @@ -1,3 +1,18 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -24,6 +39,11 @@ void wrapped_destroy_output_context(int64_t output_context) wholegraph_torch::destroy_output_context(reinterpret_cast(output_context)); } +void wrapped_free_context_data(int64_t output_context) +{ + wholegraph_torch::free_context_data(reinterpret_cast(output_context), nullptr); +} + torch::Tensor get_torch_tensor_from_output_context(int64_t output_context) { auto* torch_output_context = @@ -39,6 +59,7 @@ PYBIND11_MODULE(pylibwholegraph_torch_ext, m) m.def("get_stream", &wrapped_get_stream, "Get current CUDA stream."); m.def("create_output_context", &wrapped_create_output_context, "Create output memory context."); m.def("destroy_output_context", &wrapped_destroy_output_context, "Destroy output memory context."); + m.def("free_context_data", &wrapped_free_context_data, "Free data in output memory context."); m.def("get_tensor_from_context", &get_torch_tensor_from_output_context, "Get PyTorch Tensor from output memory context");