Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: unsubscribe pub/sub connections after cluster migration #4529

Merged
merged 8 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>

#include <memory>
#include <string_view>

#include "core/heap_size.h"
#include "facade/acl_commands_def.h"
Expand Down Expand Up @@ -34,6 +35,10 @@ class ConnectionContext {

virtual size_t UsedMemory() const;

// Noop.
virtual void Unsubscribe(std::string_view channel) {
}

// connection state / properties.
bool conn_closing : 1;
bool req_auth : 1;
Expand Down
13 changes: 13 additions & 0 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,17 @@ void Connection::AsyncOperations::operator()(const AclUpdateMessage& msg) {

void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;

if (pub_msg.should_unsubscribe) {
rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rbuilder->SendBulkString("unsubscribe");
rbuilder->SendBulkString(pub_msg.channel);
rbuilder->SendLong(0);
auto* cntx = self->cntx();
cntx->Unsubscribe(pub_msg.channel);
return;
}

unsigned i = 0;
array<string_view, 4> arr;
if (pub_msg.pattern.empty()) {
Expand All @@ -502,8 +513,10 @@ void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
arr[i++] = "pmessage";
arr[i++] = pub_msg.pattern;
}

arr[i++] = pub_msg.channel;
arr[i++] = pub_msg.message;

rbuilder->SendBulkStrArr(absl::Span<string_view>{arr.data(), i},
RedisReplyBuilder::CollectionType::PUSH);
}
Expand Down
1 change: 1 addition & 0 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class Connection : public util::Connection {
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<char[]> buf; // stores channel name and message
std::string_view channel, message; // channel and message parts from buf
bool should_unsubscribe = false; // unsubscribe from channel after sending the message
};

// Pipeline message, accumulated Redis command to be executed.
Expand Down
110 changes: 104 additions & 6 deletions src/server/channel_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "base/logging.h"
#include "core/glob_matcher.h"
#include "server/cluster/slot_set.h"
#include "server/cluster_support.h"
#include "server/engine_shard_set.h"
#include "server/server_state.h"

Expand All @@ -17,7 +19,7 @@ using namespace std;
namespace {

// Build functor for sending messages to connection
auto BuildSender(string_view channel, facade::ArgRange messages) {
auto BuildSender(string_view channel, facade::ArgRange messages, bool unsubscribe = false) {
absl::FixedArray<string_view, 1> views(messages.Size());
size_t messages_size = accumulate(messages.begin(), messages.end(), 0,
[](int sum, string_view str) { return sum + str.size(); });
Expand All @@ -34,11 +36,12 @@ auto BuildSender(string_view channel, facade::ArgRange messages) {
}
}

return [channel, buf = std::move(buf), views = std::move(views)](facade::Connection* conn,
string pattern) {
return [channel, buf = std::move(buf), views = std::move(views), unsubscribe](
facade::Connection* conn, string pattern) {
string_view channel_view{buf.get(), channel.size()};
for (std::string_view message_view : views)
conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view});
for (std::string_view message_view : views) {
conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view, unsubscribe});
}
};
}

Expand Down Expand Up @@ -144,7 +147,6 @@ unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange m
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
ChannelStore::Subscriber::ByThreadId);
while (it != subscribers_ptr->end() && it->Thread() == idx) {
// if ptr->cntx() is null, a connection might have closed or be in the process of closing
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr)
send(ptr, it->pattern);
it++;
Expand Down Expand Up @@ -196,6 +198,45 @@ size_t ChannelStore::PatternCount() const {
return patterns_->size();
}

void ChannelStore::UnsubscribeAfterClusterSlotMigration(const cluster::SlotSet& deleted_slots) {
if (deleted_slots.Empty()) {
return;
}

const uint32_t tid = util::ProactorBase::me()->GetPoolIndex();
ChannelStoreUpdater csu(false, false, nullptr, tid);

for (const auto& [channel, _] : *channels_) {
auto channel_slot = KeySlot(channel);
if (deleted_slots.Contains(channel_slot)) {
csu.Record(channel);
}
}

csu.ApplyAndUnsubscribe();
}

void ChannelStore::UnsubscribeConnectionsFromDeletedSlots(const ChannelsSubMap& sub_map,
uint32_t idx) {
const bool should_unsubscribe = true;
for (const auto& [channel, subscribers] : sub_map) {
// ignored by pub sub handler because should_unsubscribe is true
std::string msg = "__ignore__";
auto send = BuildSender(channel, {facade::ArgSlice{msg}}, should_unsubscribe);

auto it = lower_bound(subscribers.begin(), subscribers.end(), idx,
ChannelStore::Subscriber::ByThreadId);
while (it != subscribers.end() && it->Thread() == idx) {
// if ptr->cntx() is null, a connection might have closed or be in the process of closing
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr) {
DCHECK(it->pattern.empty());
send(ptr, it->pattern);
}
++it;
}
}
}

ChannelStoreUpdater::ChannelStoreUpdater(bool pattern, bool to_add, ConnectionContext* cntx,
uint32_t thread_id)
: pattern_{pattern}, to_add_{to_add}, cntx_{cntx}, thread_id_{thread_id} {
Expand Down Expand Up @@ -295,4 +336,61 @@ void ChannelStoreUpdater::Apply() {
delete ptr;
}

void ChannelStoreUpdater::ApplyAndUnsubscribe() {
DCHECK(to_add_ == false);
DCHECK(pattern_ == false);
DCHECK(cntx_ == nullptr);

if (ops_.empty()) {
return;
}

// Wait for other updates to finish, lock the control block and update store pointer.
auto& cb = ChannelStore::control_block;
cb.update_mu.lock();
auto* store = cb.most_recent.load(memory_order_relaxed);

// Deep copy, we will remove channels
auto* target = new ChannelStore::ChannelMap{*store->channels_};

for (auto key : ops_) {
auto it = target->find(key);
freelist_.push_back(it->second.Get());
target->erase(it);
continue;
}

// Prepare replacement.
auto* replacement = new ChannelStore{target, store->patterns_};

// Update control block and unlock it.
cb.most_recent.store(replacement, memory_order_relaxed);
cb.update_mu.unlock();
Comment on lines +350 to +368
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use lock_guard

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually prefer using lock_guard with a scope {}. However, here I need to release the lock on line 367 but I need to use the variables defined in the lock scope store at the end of this function. Therefore, to avoid declaring those in an outerscope I rather do this manually via lock and unlock


// FetchSubscribers is not thead safe so we need to fetch here before we do the hop below.
// Bonus points because now we compute subscribers only once.
absl::flat_hash_map<std::string_view, std::vector<ChannelStore::Subscriber>> subs;
for (auto channel : ops_) {
auto channel_subs = ServerState::tlocal()->channel_store()->FetchSubscribers(channel);
DCHECK(!subs.contains(channel));
subs[channel] = std::move(channel_subs);
}
// Update thread local references. Readers fetch subscribers via FetchSubscribers,
// which runs without preemption, and store references to them in self container Subscriber
// structs. This means that any point on the other thread is safe to update the channel store.
// Regardless of whether we need to replace, we dispatch to make sure all
// queued SubscribeMaps in the freelist are no longer in use.
shard_set->pool()->AwaitFiberOnAll([&subs](unsigned idx, util::ProactorBase*) {
ServerState::tlocal()->UnsubscribeSlotsAndUpdateChannelStore(
subs, ChannelStore::control_block.most_recent.load(memory_order_relaxed));
});

// Delete previous map and channel store.
delete store->channels_;
delete store;

for (auto ptr : freelist_)
delete ptr;
}

} // namespace dfly
17 changes: 17 additions & 0 deletions src/server/channel_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ namespace dfly {

class ChannelStoreUpdater;

namespace cluster {
class SlotSet;
}

// ChannelStore manages PUB/SUB subscriptions.
//
// Updates are carried out via RCU (read-copy-update). Each thread stores a pointer to ChannelStore
Expand Down Expand Up @@ -61,8 +65,15 @@ class ChannelStore {
std::vector<Subscriber> FetchSubscribers(std::string_view channel) const;

std::vector<std::string> ListChannels(const std::string_view pattern) const;

size_t PatternCount() const;

void UnsubscribeAfterClusterSlotMigration(const cluster::SlotSet& deleted_slots);

using ChannelsSubMap =
absl::flat_hash_map<std::string_view, std::vector<ChannelStore::Subscriber>>;
void UnsubscribeConnectionsFromDeletedSlots(const ChannelsSubMap& sub_map, uint32_t idx);

// Destroy current instance and delete it.
static void Destroy();

Expand Down Expand Up @@ -128,6 +139,12 @@ class ChannelStoreUpdater {
void Record(std::string_view key);
void Apply();

// Used for cluster when slots migrate. We need to:
// 1. Remove the channel from the copy.
// 2. Unsuscribe all the connections from each channel.
// 3. Update the control block pointer.
void ApplyAndUnsubscribe();

private:
using ChannelMap = ChannelStore::ChannelMap;

Expand Down
5 changes: 5 additions & 0 deletions src/server/cluster/cluster_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "facade/dragonfly_connection.h"
#include "facade/error.h"
#include "server/acl/acl_commands_def.h"
#include "server/channel_store.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/dflycmd.h"
Expand Down Expand Up @@ -506,6 +507,10 @@ void DeleteSlots(const SlotRanges& slots_ranges) {
namespaces->GetDefaultNamespace().GetDbSlice(shard->shard_id()).FlushSlots(slots_ranges);
};
shard_set->pool()->AwaitFiberOnAll(std::move(cb));

auto* channel_store = ServerState::tlocal()->channel_store();
auto deleted = SlotSet(slots_ranges);
channel_store->UnsubscribeAfterClusterSlotMigration(deleted);
}

void WriteFlushSlotsToJournal(const SlotRanges& slot_ranges) {
Expand Down
12 changes: 12 additions & 0 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}

void ConnectionContext::Unsubscribe(std::string_view channel) {
auto* sinfo = conn_state.subscribe_info.get();
DCHECK(sinfo);
auto erased = sinfo->channels.erase(channel);
DCHECK(erased);
if (sinfo->IsEmpty()) {
conn_state.subscribe_info.reset();
DCHECK_GE(subscriptions, 1u);
--subscriptions;
}
}

vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);
Expand Down
2 changes: 2 additions & 0 deletions src/server/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class ConnectionContext : public facade::ConnectionContext {

size_t UsedMemory() const override;

virtual void Unsubscribe(std::string_view channel) override;

// Whether this connection is a connection from a replica to its master.
// This flag is true only on replica side, where we need to setup a special ConnectionContext
// instance that helps applying commands coming from master.
Expand Down
11 changes: 9 additions & 2 deletions src/server/server_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extern "C" {
#include "base/logging.h"
#include "facade/conn_context.h"
#include "facade/dragonfly_connection.h"
#include "server/channel_store.h"
#include "server/journal/journal.h"
#include "util/listener_interface.h"

Expand Down Expand Up @@ -261,8 +262,8 @@ void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
is_replica = dfly_conn->cntx()->replica_conn;
}

if ((phase == Phase::READ_SOCKET || dfly_conn->IsSending()) &&
!is_replica && dfly_conn->idle_time() > timeout) {
if ((phase == Phase::READ_SOCKET || dfly_conn->IsSending()) && !is_replica &&
dfly_conn->idle_time() > timeout) {
conn_refs.push_back(dfly_conn->Borrow());
}
};
Expand All @@ -285,4 +286,10 @@ void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
}
}

void ServerState::UnsubscribeSlotsAndUpdateChannelStore(const ChannelStore::ChannelsSubMap& sub_map,
ChannelStore* replacement) {
channel_store_->UnsubscribeConnectionsFromDeletedSlots(sub_map, thread_index_);
channel_store_ = replacement;
}

} // end of namespace dfly
4 changes: 4 additions & 0 deletions src/server/server_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/interpreter.h"
#include "server/acl/acl_log.h"
#include "server/acl/user_registry.h"
#include "server/channel_store.h"
#include "server/common.h"
#include "server/script_mgr.h"
#include "server/slowlog.h"
Expand Down Expand Up @@ -260,6 +261,9 @@ class ServerState { // public struct - to allow initialization.
channel_store_ = replacement;
}

void UnsubscribeSlotsAndUpdateChannelStore(const ChannelStore::ChannelsSubMap& sub_map,
ChannelStore* replacement);

bool ShouldLogSlowCmd(unsigned latency_usec) const;

Stats stats;
Expand Down
48 changes: 48 additions & 0 deletions tests/dragonfly/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2980,3 +2980,51 @@ async def test_cluster_sharded_pub_sub(df_factory: DflyInstanceFactory):
await c_nodes[0].execute_command("SPUBLISH kostas new_message")
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0}


@dfly_args({"proactor_threads": 2, "cluster_mode": "yes"})
async def test_cluster_sharded_pub_sub_migration(df_factory: DflyInstanceFactory):
instances = [df_factory.create(port=next(next_port)) for i in range(2)]
df_factory.start_all(instances)

c_nodes = [instance.client() for instance in instances]

nodes = [(await create_node_info(instance)) for instance in instances]
nodes[0].slots = [(0, 16383)]
nodes[1].slots = []

await push_config(json.dumps(generate_config(nodes)), [node.client for node in nodes])

# Setup producer and consumer
node_a = ClusterNode("localhost", instances[0].port)
node_b = ClusterNode("localhost", instances[1].port)

consumer_client = RedisCluster(startup_nodes=[node_a, node_b])
consumer = consumer_client.pubsub()
consumer.ssubscribe("kostas")

# Push new config
nodes[0].migrations.append(
MigrationInfo("127.0.0.1", nodes[1].instance.port, [(0, 16383)], nodes[1].id)
)
await push_config(json.dumps(generate_config(nodes)), [node.client for node in nodes])

await wait_for_status(nodes[0].client, nodes[1].id, "FINISHED")

nodes[0].migrations = []
nodes[0].slots = []
nodes[1].slots = [(0, 16383)]
logging.debug("remove finished migrations")
await push_config(json.dumps(generate_config(nodes)), [node.client for node in nodes])

# channel name kostas crc is at slot 2883 which is part of the second now.
with pytest.raises(redis.exceptions.ResponseError) as moved_error:
await c_nodes[0].execute_command("SSUBSCRIBE kostas")

assert str(moved_error.value) == f"MOVED 2833 127.0.0.1:{instances[1].port}"

# Consume subscription message result from above
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "subscribe", "pattern": None, "channel": b"kostas", "data": 1}
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0}
Loading