Skip to content

Commit

Permalink
refactor(interactive): Pass worker_id as a parameter instead of lazy_static (#4060)
Browse files Browse the repository at this point in the history

Fixes #4173
  • Loading branch information
lnfjpt authored Aug 27, 2024
1 parent bbd05f8 commit d306ca6
Show file tree
Hide file tree
Showing 31 changed files with 205 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ impl<T: Decode> Decode for Option<T> {

mod shade;
mod third_party;

pub use shade::ShadeCodec;
#[cfg(feature = "serde")]
pub use third_party::serde_bin as serde;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::message::MessageHeader;
use crate::{NetError, Server};

mod encode;

pub use encode::{GeneralEncoder, MessageEncoder, SimpleEncoder, SlabEncoder};

mod net_tx;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,13 @@ impl<T: Data> Channel<T> {
}

impl<T: Data> Channel<T> {
// NOTE(review): this span is diff-hunk residue whose +/- markers were lost, so
// removed and added lines appear back to back. The enclosing hunk header reads
// "-162,14 +162,13"; the lines that omit `worker_id` (or look the worker up via
// `crate::worker_id::get_current_worker()`) are presumably the removed side and
// the lines that pass `worker_id` the added side — confirm against the commit.
/// Builds the pipeline form of this channel and returns it as a
/// `MaterializedChannel` with `notify: None`.
///
/// The push/pull pair comes from `crate::data_plane::pipeline` for `id`; the
/// sending half is wrapped in a `LocalMicroBatchPush`, then in a
/// `PerChannelPush` together with a `SingleConsCancel` cancel handle.
fn build_pipeline(self, target: Port, id: ChannelId) -> MaterializedChannel<T> {
// Variant of the signature that receives the worker index as `worker_id`.
fn build_pipeline(self, target: Port, id: ChannelId, worker_id: u32) -> MaterializedChannel<T> {
let (tx, rx) = crate::data_plane::pipeline::<MicroBatch<T>>(id);
let scope_level = self.get_scope_level();
// Both count arguments are the literal 1 here.
let ch_info = ChannelInfo::new(id, scope_level, 1, 1, self.source, target);
// Variant without `worker_id`: the worker index is fetched from
// `crate::worker_id::get_current_worker()` and only given to the cancel handle.
let push = MicroBatchPush::Pipeline(LocalMicroBatchPush::new(ch_info, tx));
let worker = crate::worker_id::get_current_worker().index;
let ch = CancelHandle::SC(SingleConsCancel::new(worker));
let push = PerChannelPush::new(ch_info, self.scope_delta, push, ch);
// Variant with `worker_id`: the same three steps, with `worker_id` passed to
// `LocalMicroBatchPush::new`, `SingleConsCancel::new` and `PerChannelPush::new`.
let push = MicroBatchPush::Pipeline(LocalMicroBatchPush::new(ch_info, tx, worker_id));
let ch = CancelHandle::SC(SingleConsCancel::new(worker_id));
let push = PerChannelPush::new(ch_info, self.scope_delta, push, ch, worker_id);
MaterializedChannel { push, pull: rx.into(), notify: None }
}

Expand All @@ -179,14 +178,14 @@ impl<T: Data> Channel<T> {
(ChannelInfo, Vec<EventEmitPush<T>>, GeneralPull<MicroBatch<T>>, GeneralPush<MicroBatch<T>>),
BuildJobError,
> {
let (mut raw, pull) = crate::communication::build_channel::<MicroBatch<T>>(id, &dfb.config)?.take();
let worker_index = crate::worker_id::get_current_worker().index as usize;
let (mut raw, pull) =
crate::communication::build_channel::<MicroBatch<T>>(id, &dfb.config, dfb.worker_id)?.take();
let worker_index = dfb.worker_id.index as usize;
let notify = raw.swap_remove(worker_index);
let ch_info = ChannelInfo::new(id, scope_level, raw.len(), raw.len(), self.source, target);
let mut pushes = Vec::with_capacity(raw.len());
let source = dfb.worker_id.index;
for (idx, p) in raw.into_iter().enumerate() {
let push = EventEmitPush::new(ch_info, source, idx as u32, p, dfb.event_emitter.clone());
let push = EventEmitPush::new(ch_info, dfb.worker_id, idx as u32, p, dfb.event_emitter.clone());
pushes.push(push);
}
Ok((ch_info, pushes, pull, notify))
Expand All @@ -212,51 +211,69 @@ impl<T: Data> Channel<T> {
}

if dfb.worker_id.total_peers() == 1 {
return Ok(self.build_pipeline(target, id));
return Ok(self.build_pipeline(target, id, dfb.worker_id.index));
}

let kind = std::mem::replace(&mut self.kind, ChannelKind::Pipeline);
match kind {
ChannelKind::Pipeline => Ok(self.build_pipeline(target, id)),
ChannelKind::Pipeline => Ok(self.build_pipeline(target, id, dfb.worker_id.index)),
ChannelKind::Shuffle(r) => {
let (info, pushes, pull, notify) = self.build_remote(scope_level, target, id, dfb)?;
let mut buffers = Vec::with_capacity(pushes.len());
for _ in 0..pushes.len() {
let b = ScopeBufferPool::new(batch_size, batch_capacity, scope_level);
buffers.push(b);
}
let push = ExchangeByDataPush::new(info, r, buffers, pushes);
let push = ExchangeByDataPush::new(info, r, buffers, pushes, dfb.worker_id);
let ch = push.get_cancel_handle();
let push = PerChannelPush::new(info, self.scope_delta, MicroBatchPush::Exchange(push), ch);
let push = PerChannelPush::new(
info,
self.scope_delta,
MicroBatchPush::Exchange(push),
ch,
dfb.worker_id.index,
);
Ok(MaterializedChannel { push, pull: pull.into(), notify: Some(notify) })
}
ChannelKind::BatchShuffle(route) => {
let (info, pushes, pull, notify) = self.build_remote(scope_level, target, id, dfb)?;
let push = ExchangeByBatchPush::new(info, route, pushes);
let push = ExchangeByBatchPush::new(info, route, pushes, dfb.worker_id);
let cancel = push.get_cancel_handle();
let push = PerChannelPush::new(
info,
self.scope_delta,
MicroBatchPush::ExchangeByBatch(push),
cancel,
dfb.worker_id.index,
);
Ok(MaterializedChannel { push, pull: pull.into(), notify: Some(notify) })
}
ChannelKind::Broadcast => {
let (info, pushes, pull, notify) = self.build_remote(scope_level, target, id, dfb)?;
let push = BroadcastBatchPush::new(info, pushes);
let push = BroadcastBatchPush::new(info, pushes, dfb.worker_id.total_peers());
let ch = push.get_cancel_handle();
let push = PerChannelPush::new(info, self.scope_delta, MicroBatchPush::Broadcast(push), ch);
let push = PerChannelPush::new(
info,
self.scope_delta,
MicroBatchPush::Broadcast(push),
ch,
dfb.worker_id.index,
);
Ok(MaterializedChannel { push, pull: pull.into(), notify: Some(notify) })
}
ChannelKind::Aggregate => {
let (mut ch_info, pushes, pull, notify) =
self.build_remote(scope_level, target, id, dfb)?;
ch_info.target_peers = 1;
let push = AggregateBatchPush::new(ch_info, pushes);
let push = AggregateBatchPush::new(ch_info, pushes, dfb.worker_id);
let cancel = push.get_cancel_handle();
let push =
PerChannelPush::new(ch_info, self.scope_delta, MicroBatchPush::Aggregate(push), cancel);
let push = PerChannelPush::new(
ch_info,
self.scope_delta,
MicroBatchPush::Aggregate(push),
cancel,
dfb.worker_id.index,
);
Ok(MaterializedChannel { push, pull: pull.into(), notify: Some(notify) })
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::communication::decorator::exchange::ExchangeByBatchPush;
use crate::data::MicroBatch;
use crate::data_plane::Push;
use crate::errors::IOError;
use crate::Data;
use crate::{Data, WorkerId};

struct ScopedAggregate<D: Data>(PhantomData<D>);

Expand All @@ -30,14 +30,18 @@ pub struct AggregateBatchPush<D: Data> {
}

impl<D: Data> AggregateBatchPush<D> {
pub fn new(info: ChannelInfo, pushes: Vec<EventEmitPush<D>>) -> Self {
pub fn new(info: ChannelInfo, pushes: Vec<EventEmitPush<D>>, worker_id: WorkerId) -> Self {
if info.scope_level == 0 {
let push = ExchangeByBatchPush::new(info, BatchRoute::AllToOne(0), pushes);
let push = ExchangeByBatchPush::new(info, BatchRoute::AllToOne(0), pushes, worker_id);
AggregateBatchPush { push }
} else {
let chancel_handle = DynSingleConsCancelPtr::new(info.scope_level, pushes.len());
let mut push =
ExchangeByBatchPush::new(info, BatchRoute::Dyn(Box::new(ScopedAggregate::new())), pushes);
let mut push = ExchangeByBatchPush::new(
info,
BatchRoute::Dyn(Box::new(ScopedAggregate::new())),
pushes,
worker_id,
);
push.update_cancel_handle(CancelHandle::DSC(chancel_handle));
AggregateBatchPush { push }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ pub struct BroadcastBatchPush<D: Data> {
pub ch_info: ChannelInfo,
pushes: Vec<EventEmitPush<D>>,
cancel_handle: MultiConsCancelPtr,
total_peers: u32,
}

impl<D: Data> BroadcastBatchPush<D> {
// NOTE(review): diff-hunk residue with the +/- markers lost. The enclosing hunk
// header reads "-11,12 +11,13", so the signature and struct literal without
// `total_peers` are presumably the removed lines and the ones with it the added
// lines — confirm against the commit.
/// Constructs a `BroadcastBatchPush` from `ch_info` and `pushes`.
///
/// The cancel handle is a `MultiConsCancelPtr` created from the channel's
/// `scope_level` and the number of entries in `pushes`.
pub fn new(ch_info: ChannelInfo, pushes: Vec<EventEmitPush<D>>) -> Self {
// Variant of the signature that also receives `total_peers`.
pub fn new(ch_info: ChannelInfo, pushes: Vec<EventEmitPush<D>>, total_peers: u32) -> Self {
let cancel_handle = MultiConsCancelPtr::new(ch_info.scope_level, pushes.len());
BroadcastBatchPush { ch_info, pushes, cancel_handle }
// Variant of the struct literal that stores `total_peers` on the value.
BroadcastBatchPush { ch_info, pushes, cancel_handle, total_peers }
}

pub(crate) fn get_cancel_handle(&self) -> CancelHandle {
Expand All @@ -38,14 +39,14 @@ impl<D: Data> BroadcastBatchPush<D> {

if let Some(mut end) = batch.take_end() {
if end.peers().value() == 1 && end.peers_contains(self.pushes[target].source_worker) {
end.update_peers(DynPeers::all());
end.update_peers(DynPeers::all(self.total_peers), self.total_peers);
batch.set_end(end);
self.pushes[target].push(batch)?;
} else {
if !batch.is_empty() {
self.pushes[target].push(batch)?;
}
self.pushes[target].sync_end(end, DynPeers::all())?;
self.pushes[target].sync_end(end, DynPeers::all(self.total_peers))?;
}
} else {
self.pushes[target].push(batch)?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ use crate::event::emitter::EventEmitter;
use crate::event::{Event, EventKind};
use crate::progress::{DynPeers, EndOfScope, EndSyncSignal};
use crate::tag::tools::map::TidyTagMap;
use crate::PROFILE_COMM_FLAG;
use crate::{Data, Tag};
use crate::{WorkerId, PROFILE_COMM_FLAG};

#[allow(dead_code)]
pub struct EventEmitPush<T: Data> {
pub ch_info: ChannelInfo,
pub total_peers: u32,
pub source_worker: u32,
pub target_worker: u32,
inner: GeneralPush<MicroBatch<T>>,
Expand All @@ -38,13 +39,14 @@ pub struct EventEmitPush<T: Data> {
#[allow(dead_code)]
impl<T: Data> EventEmitPush<T> {
pub fn new(
info: ChannelInfo, source_worker: u32, target_worker: u32, push: GeneralPush<MicroBatch<T>>,
info: ChannelInfo, worker_id: WorkerId, target_worker: u32, push: GeneralPush<MicroBatch<T>>,
emitter: EventEmitter,
) -> Self {
let push_counts = TidyTagMap::new(info.scope_level);
EventEmitPush {
ch_info: info,
source_worker,
total_peers: worker_id.total_peers(),
source_worker: worker_id.index,
target_worker,
inner: push,
event_emitter: emitter,
Expand Down Expand Up @@ -75,14 +77,14 @@ impl<T: Data> EventEmitPush<T> {
end.peers(),
children
);
end.update_peers(children);
end.update_peers(children, self.total_peers);
let end_batch = MicroBatch::last(self.source_worker, end);
self.push(end_batch)
} else {
Ok(())
}
} else {
end.update_peers(children);
end.update_peers(children, self.total_peers);
let end_batch = MicroBatch::last(self.source_worker, end);
self.push(end_batch)
}
Expand Down
Loading

0 comments on commit d306ca6

Please sign in to comment.