Commit

refactor: API now uses process controller
le1nux committed Jan 5, 2025
1 parent a05bd9f commit ac8f74e
Showing 2 changed files with 90 additions and 71 deletions.
157 changes: 88 additions & 69 deletions src/modalities/api.py
@@ -1,35 +1,33 @@
#!/usr/bin/env python

import multiprocessing as mp
import os
from enum import Enum
from pathlib import Path
from typing import Optional

from modalities.dataloader.preprocessing.tokenization.create_packed_data import PackedDataGenerator
from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import (
    EmbeddedStreamData,
    join_embedded_stream_data,
)
from modalities.dataloader.preprocessing.tokenization.tokenization_processes import (
    ProcessFactory,
    ProgressLoggingWorker,
    get_required_num_of_bytes_to_repr,
)
from modalities.utils.logging import get_logger
from pydantic import FilePath

import modalities.inference.inference as inference
from modalities.checkpointing.checkpoint_conversion import CheckpointConversion
from modalities.config.component_factory import ComponentFactory
from modalities.config.instantiation_models import PackedDatasetComponentsInstantiationModel
from modalities.config.instantiation_models import TokenizationInstantiationModel
from modalities.dataloader.preprocessing.indexation.create_index import IndexGenerator
from modalities.dataloader.preprocessing.queued_processing.process_controller import PipelineStep, ProcessController
from modalities.dataloader.preprocessing.queued_processing.queued_processing import Processor
from modalities.dataloader.preprocessing.tokenization.embedded_stream_data import (
    EmbeddedStreamData,
    join_embedded_stream_data,
)
from modalities.dataloader.preprocessing.tokenization.large_file_lines_reader import LocalLargeFileLinesReader
from modalities.dataloader.preprocessing.tokenization.tokenization_strategies import (
    ProcessingStrategyFactory,
    WorkerTypes,
    populate_reader_q,
)
from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
import multiprocessing as mp
import shutil

from enum import Enum
from modalities.utils.logging import get_logger


class FileExistencePolicy(Enum):
@@ -122,70 +120,91 @@ def pack_encoded_data(config_dict: dict):
    # ResolverRegistry to work dynamically with any type-hinted config object from config.py.
    registry = Registry(COMPONENTS)
    component_factory = ComponentFactory(registry=registry)
    instantion_model: PackedDatasetComponentsInstantiationModel = component_factory.build_components(
        config_dict=config_dict, components_model_type=PackedDatasetComponentsInstantiationModel
    instantion_model: TokenizationInstantiationModel = component_factory.build_components(
        config_dict=config_dict, components_model_type=TokenizationInstantiationModel
    )

    # build the queues
    reader_q, tokenizer_q, writer_q, logging_message_q = ProcessFactory.get_process_queues(
    reader_q, tokenizer_q, writer_q, logging_message_q = ProcessingStrategyFactory.get_process_queues(
        writer_q_maxsize=instantion_model.writer_q_maxsize, tokenizer_q_maxsize=instantion_model.tokenizer_q_maxsize
    )

    # build the workers
    stop_event = mp.Event()
    token_size_in_bytes = get_required_num_of_bytes_to_repr(
        instantion_model.tokenizer_worker_settings.tokenizer_settings.tokenizer.vocab_size
    )

    reader_workers = ProcessFactory.get_reader_workers(
        rw_settings=instantion_model.reader_worker_settings,
        reader_q=reader_q,
        tokenizer_q=tokenizer_q,
        logging_message_q=logging_message_q,
        stop_event=stop_event,
    )

    tokenizer_workers = ProcessFactory.get_tokenizer_workers(
        tw_settings=instantion_model.tokenizer_worker_settings,
        tokenizer_q=tokenizer_q,
        writer_q=writer_q,
        logging_message_q=logging_message_q,
        token_size_in_bytes=token_size_in_bytes,
    tokenizer_q_key = "tokenizer_q"
    writer_q_key = "writer_q"
    logging_message_q_key = "logging_message_q"

    reader_settings = instantion_model.reader_worker_settings.reader_settings

    reader_workers = [
        Processor(
            in_q=reader_q,
            out_qs={tokenizer_q_key: tokenizer_q, logging_message_q_key: logging_message_q},
            in_q_timeout=instantion_model.in_q_timeout,
            out_q_timeout=instantion_model.out_q_timeout,
            strategy=ProcessingStrategyFactory.get_reader_strategy(
                reader_settings, tokenizer_q_key=tokenizer_q_key, logging_message_q_key=logging_message_q_key
            ),
            process_type=WorkerTypes.READER,
            process_id=i,
            logging_message_q_key=logging_message_q_key,
            stop_event=stop_event,
        )
        for i in range(instantion_model.reader_worker_settings.num_workers)
    ]

    tokenizer_workers = [
        Processor(
            in_q=tokenizer_q,
            out_qs={writer_q_key: writer_q, logging_message_q_key: logging_message_q},
            in_q_timeout=instantion_model.in_q_timeout,
            out_q_timeout=instantion_model.out_q_timeout,
            strategy=ProcessingStrategyFactory.get_tokenizer_strategy(
                tokenizer_settings=instantion_model.tokenizer_worker_settings.tokenizer_settings,
                writer_q_key=writer_q_key,
                logging_message_q_key=logging_message_q_key,
            ),
            process_type=WorkerTypes.TOKENIZER,
            process_id=i,
            logging_message_q_key=logging_message_q_key,
            stop_event=stop_event,
        )
        for i in range(instantion_model.tokenizer_worker_settings.num_workers)
    ]

    writer_worker = Processor(
        in_q=writer_q,
        out_qs={logging_message_q_key: logging_message_q},
        in_q_timeout=instantion_model.in_q_timeout,
        out_q_timeout=instantion_model.out_q_timeout,
        strategy=ProcessingStrategyFactory.get_writing_strategy(
            ww_settings=instantion_model.writer_worker_settings, logging_message_q_key=logging_message_q_key
        ),
        process_type=WorkerTypes.WRITER,
        process_id=0,
        logging_message_q_key=logging_message_q_key,
        stop_event=stop_event,
    )

    writer_worker = ProcessFactory.get_writer_worker(
        writer_q=writer_q,
        logging_message_q=logging_message_q,
        token_size_in_bytes=token_size_in_bytes,
        ww_settings=instantion_model.writer_worker_settings,
        stop_event=stop_event,
    )

    progress_logging_worker = ProgressLoggingWorker(
        logging_message_q=logging_message_q,
        reader_q=reader_q,
        tokenizer_q=tokenizer_q,
        writer_q=writer_q,
        total_num_samples=instantion_model.num_samples,
        stop_event=stop_event,
        logging_interval=instantion_model.logging_interval,
    )

    generator = PackedDataGenerator(
        reader_workers=reader_workers,
        tokenizer_workers=tokenizer_workers,
        writer_worker=writer_worker,
        progress_logging_worker=progress_logging_worker,
        reader_q=reader_q,
        tokenizer_q=tokenizer_q,
        writer_q=writer_q,
        logging_message_q=logging_message_q,
        index_start=instantion_model.index_start,
        num_samples=instantion_model.num_samples,
        batch_size=instantion_model.batch_size,
    )
    generator.run()
    pipeline_steps = [
        PipelineStep(name="reading", input_queue=reader_q, processors=reader_workers),
        PipelineStep(name="tokenizing", input_queue=tokenizer_q, processors=tokenizer_workers),
        PipelineStep(name="writing", input_queue=writer_q, processors=[writer_worker]),
    ]

    def populate():
        populate_reader_q(
            reader_q=reader_q,
            index_start=instantion_model.index_start,
            num_samples=instantion_model.num_samples,
            num_reader_processes=instantion_model.reader_worker_settings.num_workers,
            batch_size=instantion_model.batch_size,
        )

    process_controller = ProcessController(pipeline_steps=pipeline_steps, populate_jobs=populate)
    process_controller.run()


def merge_packed_data_files(src_paths: list[Path], target_path: Path):
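The new pack_encoded_data wiring above follows a staged producer/consumer layout: a populate job seeds the reader queue, each PipelineStep pairs an input queue with its processors, and the ProcessController starts the stages and drains them in order. The internals of Processor, PipelineStep and ProcessController are not part of this diff, so the sketch below only illustrates the general pattern with the standard library; the names (stage_worker, run_pipeline), the stand-in strategies, and the sentinel-based shutdown are assumptions made for illustration, whereas the modalities classes use a shared stop event and queue timeouts instead.

import multiprocessing as mp

SENTINEL = None  # marks "no more work" for a stage (illustrative choice, not the modalities mechanism)


def stage_worker(in_q: mp.Queue, out_q: mp.Queue, fn) -> None:
    # consume items until the sentinel arrives, pass each result downstream
    while True:
        item = in_q.get()
        if item is SENTINEL:
            break
        out_q.put(fn(item))


def read(item):  # stand-in for the reading strategy
    return f"line-{item}"


def tokenize(line):  # stand-in for the tokenization strategy
    return line.split("-")


def write(tokens):  # stand-in for the writing strategy
    print(tokens)


def run_pipeline(num_samples: int = 8, workers_per_stage: int = 2) -> None:
    reader_q, tokenizer_q, writer_q, done_q = (mp.Queue() for _ in range(4))
    steps = [  # (name, input queue, output queue, strategy, number of processors)
        ("reading", reader_q, tokenizer_q, read, workers_per_stage),
        ("tokenizing", tokenizer_q, writer_q, tokenize, workers_per_stage),
        ("writing", writer_q, done_q, write, 1),
    ]
    procs_per_step = []
    for _, in_q, out_q, fn, n in steps:
        procs = [mp.Process(target=stage_worker, args=(in_q, out_q, fn)) for _ in range(n)]
        for p in procs:
            p.start()
        procs_per_step.append(procs)

    # the "populate job": seed the first queue with work items
    for i in range(num_samples):
        reader_q.put(i)

    # drain the pipeline step by step: only after a step's processes have been
    # joined can the next step safely be told that no more input will arrive
    for (_, in_q, _, _, n), procs in zip(steps, procs_per_step):
        for _ in range(n):
            in_q.put(SENTINEL)
        for p in procs:
            p.join()


if __name__ == "__main__":
    run_pipeline()

Shutting each stage down only after the previous stage has been joined guarantees that its sentinels arrive after all real work items, which mirrors the reading → tokenizing → writing order of the pipeline_steps list in the diff.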
4 changes: 2 additions & 2 deletions src/modalities/config/component_factory.py
@@ -75,7 +75,7 @@ def _build_component(
            # instantiate component config
            component_key = current_component_config["component_key"]
            variant_key = current_component_config["variant_key"]
            current_component_config = self._instantiate_component_config(
            current_component_config = self.instantiate_component_config(
                component_key=component_key,
                variant_key=variant_key,
                config_dict=materialized_component_config["config"],
@@ -139,7 +139,7 @@ def _is_reference_config(config_dict: dict) -> bool:
        # TODO instead of field checks, we should introduce an enum for the config type.
        return {"instance_key", "pass_type"} == config_dict.keys()

    def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
    def instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
        component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key)
        self._assert_valid_config_keys(
            component_key=component_key,
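The only functional change in this file is that the private helper _instantiate_component_config is renamed to the public instantiate_component_config (with its call site in _build_component updated above), making it callable from outside the factory's internal traversal. A minimal usage sketch, assuming the modalities package is installed; the component key, variant key, and config payload below are placeholders for illustration, not values taken from this commit:

from modalities.config.component_factory import ComponentFactory
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry

component_factory = ComponentFactory(registry=Registry(COMPONENTS))

# Validate the config of a single registered component variant in isolation.
# "tokenizer" / "pretrained_hf_tokenizer" and the config keys are assumed
# placeholders; use keys that actually exist in your registry.
config_model = component_factory.instantiate_component_config(
    component_key="tokenizer",
    variant_key="pretrained_hf_tokenizer",
    config_dict={"pretrained_model_name_or_path": "gpt2"},
)
print(type(config_model).__name__)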
