Skip to content

Commit

Permalink
PR tensorflow#21265: Attempt to add pmap free-threading support
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#21265

Description:
- A tentative to add free-threading to pmap_lib
Copybara import of the project:

--
d2f5df9c0decdb7e55a2013f5506dee6fc358298 by vfdev-5 <[email protected]>:

WIP

Merging this change closes tensorflow#21265

PiperOrigin-RevId: 714696059
  • Loading branch information
vfdev-5 authored and tensorflower-gardener committed Jan 12, 2025
1 parent d025a0e commit e26ce85
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions third_party/xla/xla/python/pmap_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,14 @@ class PmapFunction {
return inspect->attr("signature")(fun_);
}

int cache_size() const { return executables_.size(); }
void cache_clear() { return executables_.clear(); }
int cache_size() {
nb::ft_lock_guard lock(mu_);
return executables_.size();
}
void cache_clear() {
nb::ft_lock_guard lock(mu_);
return executables_.clear();
}
const nb::callable& fun() const { return fun_; }
const nb::callable& cache_miss() const { return cache_miss_; }
const std::string& function_name() const { return function_name_; }
Expand Down Expand Up @@ -406,7 +412,8 @@ class PmapFunction {
// cache and recompiles), the list of the string representations of the keys.
//
// The format can change at any time.
std::string DebugCacheKeys() const {
std::string DebugCacheKeys() {
nb::ft_lock_guard lock(mu_);
std::vector<std::string> key_strings = {
absl::StrCat("The cache contains ", executables_.size(), " elements:")};
// We will be able to use auto& [key, _] when TF uses C++ 17.
Expand Down Expand Up @@ -441,6 +448,9 @@ class PmapFunction {
// The fallback function to use with `ShardArgs`.
// TODO(jblespiau): Add support for more types from C++.
nb::callable python_shard_arg_fallback_;

// Protect methods in FT:
nb::ft_mutex mu_;
};

void PmapFunction::PopulateCacheEntry(PmapCacheEntry& cache_entry,
Expand Down Expand Up @@ -584,8 +594,11 @@ absl::StatusOr<nb::object> PmapFunction::Call(nb::handle callable,

// Retrieve/Maybe add the executable to the cache.
bool inserted = false;
std::shared_ptr<PmapCacheEntry>& cache_entry_ptr =
executables_[call_signature];
std::shared_ptr<PmapCacheEntry> cache_entry_ptr;
{
nb::ft_lock_guard lock(mu_);
cache_entry_ptr = executables_[call_signature];
}
if (cache_entry_ptr == nullptr) {
inserted = true;
cache_entry_ptr = std::make_shared<PmapCacheEntry>(pytree_registry_.get());
Expand Down

0 comments on commit e26ce85

Please sign in to comment.