Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Jan 10, 2025
1 parent b322cfc commit d2f5df9
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions xla/python/pmap_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,14 @@ class PmapFunction {
return inspect->attr("signature")(fun_);
}

int cache_size() const { return executables_.size(); }
void cache_clear() { return executables_.clear(); }
int cache_size() {
nb::ft_lock_guard lock(mu_);
return executables_.size();
}
void cache_clear() {
nb::ft_lock_guard lock(mu_);
return executables_.clear();
}
const nb::callable& fun() const { return fun_; }
const nb::callable& cache_miss() const { return cache_miss_; }
const std::string& function_name() const { return function_name_; }
Expand Down Expand Up @@ -406,7 +412,8 @@ class PmapFunction {
// cache and recompiles), the list of the string representations of the keys.
//
// The format can change at any time.
std::string DebugCacheKeys() const {
std::string DebugCacheKeys() {
nb::ft_lock_guard lock(mu_);
std::vector<std::string> key_strings = {
absl::StrCat("The cache contains ", executables_.size(), " elements:")};
// We will be able to use auto& [key, _] when TF uses C++ 17.
Expand Down Expand Up @@ -441,6 +448,9 @@ class PmapFunction {
// The fallback function to use with `ShardArgs`.
// TODO(jblespiau): Add support for more types from C++.
nb::callable python_shard_arg_fallback_;

// Protect methods in FT:
nb::ft_mutex mu_;
};

void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry,
Expand Down Expand Up @@ -584,8 +594,11 @@ absl::StatusOr<nb::object> PmapFunction::Call(nb::handle callable,

// Retrieve/Maybe add the executable to the cache.
bool inserted = false;
std::shared_ptr<PmapCacheEntry>& cache_entry_ptr =
executables_[call_signature];
std::shared_ptr<PmapCacheEntry> cache_entry_ptr;
{
nb::ft_lock_guard lock(mu_);
cache_entry_ptr = executables_[call_signature];
}
if (cache_entry_ptr == nullptr) {
inserted = true;
cache_entry_ptr = std::make_shared<PmapCacheEntry>(pytree_registry_.get());
Expand Down

0 comments on commit d2f5df9

Please sign in to comment.