diff --git a/avtdl/core/interfaces.py b/avtdl/core/interfaces.py index 5db5387..4b7f0c4 100644 --- a/avtdl/core/interfaces.py +++ b/avtdl/core/interfaces.py @@ -136,38 +136,35 @@ def _generic_topic(self, specific_topic: str) -> str: def pub(self, topic: str, message: Record): self.logger.debug(f'on topic {topic} message "{message!r}"') - if message.chain: - matching_callbacks = self.get_matching_callbacks(topic) - for generic_topic, callbacks in matching_callbacks.items(): - for callback in callbacks: - callback(topic, message) - for generic_topic in set(self._generic_topic(t) for t in matching_callbacks.keys()): - self.add_to_history(generic_topic, message) - self.add_to_history(topic, message) - else: - matching_callbacks = self.get_matching_callbacks(topic) - for specific_topic, callbacks in matching_callbacks.items(): + matching_callbacks = self.get_matching_callbacks(topic) + for specific_topic, callbacks in matching_callbacks.items(): + if message.chain: + targeted_message = message + else: _, _, chain = self.split_message_topic(specific_topic) targeted_message = message.model_copy(deep=True) targeted_message.chain = chain - for callback in callbacks: - callback(specific_topic, targeted_message) - self.add_to_history(specific_topic, targeted_message) - self.add_to_history(topic, message) + for callback in callbacks: + callback(specific_topic, targeted_message) + + generic_topic = self._generic_topic(specific_topic) + self.add_to_history(generic_topic, targeted_message) def add_to_history(self, topic: str, message: Record): self.history[topic].append(message) def get_history(self, actor: str, entity: str, chain: str = '', direction: Literal['in', 'out'] = 'in') -> List[Record]: if direction == 'in': - topic = self.incoming_topic_for(actor, entity, chain) + topic = self.incoming_topic_for(actor, entity, '') elif direction == 'out': - topic = self.outgoing_topic_for(actor, entity, chain) + topic = self.outgoing_topic_for(actor, entity, '') else: assert False, f'unexpected direction "{direction}"' - records: List[Record] = [] - records.extend(self.history[topic]) - records.sort(key = lambda r: r.created_at) + if chain: + records = [record for record in self.history[topic] if record.chain == chain] + else: + records = list(self.history[topic]) + records.sort(key = lambda record: record.created_at) return records def get_matching_callbacks(self, topic_pattern: str) -> SubscriptionsMapping: @@ -358,6 +355,7 @@ def handle_record(self, entity: ActorEntity, record: Record) -> None: def on_record(self, entity: ActorEntity, record: Record): '''Implementation should call it for every new Record it produces''' if entity.reset_origin: + record = record.model_copy(deep=True) record.chain = '' topic = self.bus.outgoing_topic_for(self.conf.name, entity.name, record.chain) self.bus.pub(topic, record)