diff --git a/.gitignore b/.gitignore
index dd49b40d..b87a0b7d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,4 +163,6 @@ cython_debug/
 .vscode/
 *.ipynb
 runs/
-sweeps/
\ No newline at end of file
+sweeps/
+data/
+version.txt
diff --git a/examples/text_generation_inference_llama.yaml b/examples/text_generation_inference_llama.yaml
new file mode 100644
index 00000000..af4b5bba
--- /dev/null
+++ b/examples/text_generation_inference_llama.yaml
@@ -0,0 +1,34 @@
+defaults:
+  - backend: text-generation-inference # default backend
+  - benchmark: inference # default benchmark
+  - experiment # inheriting experiment schema
+  - _self_ # for hydra 1.1 compatibility
+  - override hydra/job_logging: colorlog # colorful logging
+  - override hydra/hydra_logging: colorlog # colorful logging
+
+hydra:
+  run:
+    dir: runs/${experiment_name}
+  sweep:
+    dir: sweeps/${experiment_name}
+  job:
+    chdir: true
+    env_set:
+      CUDA_VISIBLE_DEVICES: 0,1
+
+experiment_name: text_generation_inference
+model: NousResearch/Llama-2-7b-hf
+device: cuda
+
+backend:
+  no_weights: true
+  initial_isolation_check: false
+  continous_isolation_check: false
+  torch_dtype: float16
+
+benchmark:
+  input_shapes:
+    batch_size: 32
+    sequence_length: 128
+
+  new_tokens: 1000
diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py
index 457de850..774223e0 100644
--- a/optimum_benchmark/backends/base.py
+++ b/optimum_benchmark/backends/base.py
@@ -154,6 +154,17 @@ def seed(self) -> None:
         # torch.backends.cudnn.deterministic = True # same as above
         # torch.backends.cudnn.benchmark = False # might reduce performance
 
+    def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        if self.is_diffusion_pipeline():
+            # diffusion pipelines expect a list of strings as input
+            return input
+        else:
+            # models expect tensors on the target device as input
+            for key, value in input.items():
+                input[key] = value.to(self.device)
+
+        return input
+
     # compiling in openvino requires input shapes
     def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
         pass
@@ -162,10 +173,10 @@ def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
     def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
         pass
 
-    def forward(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
+    def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         return self.pretrained_model(**input, **kwargs)
 
-    def generate(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
+    def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         return self.pretrained_model.generate(**input, **kwargs)
 
     def train(
@@ -193,14 +204,14 @@ def model_shapes(self) -> Dict[str, int]:
 
     def delete_pretrained_model(self) -> None:
         if hasattr(self, "pretrained_model"):
+            LOGGER.info("\t+ Deleting pretrained model")
             del self.pretrained_model
-
         gc.collect()
 
     def delete_model_cache(self) -> None:
         LOGGER.info("\t+ Deleting model cache")
-        model_cache_path = f"models/{self.model}".replace("/", "--")
-        model_cache_path = os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), model_cache_path)
+        model_cache_folder = f"models/{self.model}".replace("/", "--")
+        model_cache_path = os.path.join(os.path.expanduser("~/.cache/huggingface/hub"), model_cache_folder)
         shutil.rmtree(model_cache_path, ignore_errors=True)
 
     def clean(self) -> None:
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index 2fe1f6f2..3d1f2466 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -21,7 +21,6 @@
 
 LOGGER = getLogger("pytorch")
 
-
 class PyTorchBackend(Backend[PyTorchConfig]):
     NAME: str = "pytorch"
 
@@ -223,21 +222,21 @@ def prepare_for_profiling(self, input_names: List[str]) -> None:
             LOGGER.info("\t+ Wrapping model with FXProfilingWrapper")
             self.pretrained_model = FXProfilingWrapper(self.pretrained_model)
 
-    def forward(self, input: Dict[str, Any], **kwargs) -> "ModelOutput":
+    def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         if self.is_diffusion_pipeline():
-            return super().forward(input, **kwargs)
+            return super().forward(input, kwargs)
         else:
-            # TODO: autocast as whole can be managed by one config/kwargs
+            # TODO: autocast as a whole can be managed by one config/kwargs ?
             with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.config.amp_autocast):
-                return super().forward(input, **kwargs)
+                return super().forward(input, kwargs)
 
-    def generate(self, input: Dict[str, torch.Tensor], **kwargs) -> "ModelOutput":
+    def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         if self.is_diffusion_pipeline():
-            return super().generate(input, **kwargs)
+            return super().generate(input, kwargs)
         else:
-            # TODO: autocast as whole can be managed by one config/kwargs
+            # TODO: autocast as a whole can be managed by one config/kwargs ?
             with torch.autocast(device_type=self.device.type, dtype=self.amp_dtype, enabled=self.config.amp_autocast):
-                return super().generate(input, **kwargs)
+                return super().generate(input, kwargs)
 
     @record_if_available
     def train(
diff --git a/optimum_benchmark/backends/text_generation_inference/backend.py b/optimum_benchmark/backends/text_generation_inference/backend.py
new file mode 100644
index 00000000..948eaaa4
--- /dev/null
+++ b/optimum_benchmark/backends/text_generation_inference/backend.py
@@ -0,0 +1,209 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
+from logging import getLogger
+from typing import TYPE_CHECKING, Any, Dict, List
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import InferenceClient
+from transformers import GenerationConfig
+
+if TYPE_CHECKING:
+    from huggingface_hub.inference._text_generation import TextGenerationResponse
+
+import docker
+import docker.errors
+import docker.types
+
+from ..base import Backend
+from ..pytorch.utils import randomize_weights
+from .config import TGIConfig
+
+# backend logger
+LOGGER = getLogger("text-generation-inference")
+
+
+class TGIBackend(Backend[TGIConfig]):
+    NAME: str = "text-generation-inference"
+
+    def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]):
+        super().__init__(model, task, device, hub_kwargs)
+        self.validate_task()
+
+        automodel = self.automodel_class.__name__
+        LOGGER.info(f"\t+ Inferred AutoModel class {automodel} for task {self.task} and model_type {self.model_type}")
+
+    def validate_task(self) -> None:
+        if self.task not in ["text-generation", "text2text-generation"]:
+            raise NotImplementedError(f"TGI does not support task {self.task}")
+
+    def configure(self, config: TGIConfig) -> None:
+        super().configure(config)
+        self.config = config
+
+        if self.config.no_weights:
+            # creates a dummy model
+            self.load_model_from_config()
+            self.save_model_snapshot()
+        else:
+            self.load_model_from_pretrained()
+        self.delete_pretrained_model()
+
+        LOGGER.info("\t+ Modifying generation config")
+        self.modify_generation_config()
+
+        LOGGER.info("\t+ Starting Docker client")
+        self.docker_client = docker.from_env()
+
+        try:
+            LOGGER.info("\t+ Checking if TGI image exists")
+            self.docker_client.images.get(f"{self.config.image}:{self.config.version}")
+        except docker.errors.ImageNotFound:
+            LOGGER.info("\t+ TGI image not found, pulling it")
+            self.docker_client.images.pull(f"{self.config.image}:{self.config.version}")
+
+        LOGGER.info("\t+ Building TGI command")
+        self.command = [
+            "--model-id",
+            self.model,
+            "--revision",
+            self.hub_kwargs["revision"],
+        ]
+
+        if self.config.quantization is not None:
+            self.command.extend(["--quantize", self.config.quantization])
+        if self.config.torch_dtype is not None:
+            self.command.extend(["--torch-dtype", self.config.torch_dtype])
+        if self.hub_kwargs.get("trust_remote_code", False):
+            self.command.append("--trust-remote-code")
+        if self.config.disable_custom_kernels:
+            self.command.append("--disable-custom-kernels")
+
+        if self.device.type == "cuda":
+            device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", self.device.index or 0)
+            LOGGER.info(f"\t+ Starting TGI container on CUDA device(s): {device_ids}")
+            device_requests = [docker.types.DeviceRequest(device_ids=[str(device_ids)], capabilities=[["gpu"]])]
+        else:
+            LOGGER.info("\t+ Starting TGI container on CPU device")
+            device_requests = None
+
+        self.tgi_container = self.docker_client.containers.run(
+            image=f"{self.config.image}:{self.config.version}",
+            command=self.command,
+            shm_size=self.config.shm_size,
+            volumes={self.config.volume: {"bind": "/data", "mode": "rw"}},
+            ports={"80/tcp": (self.config.address, self.config.port)},
+            device_requests=device_requests,
+            detach=True,
+        )
+
+        LOGGER.info("\t+ Waiting for TGI server to be ready")
+        for line in self.tgi_container.logs(stream=True):
+            tgi_log = line.decode("utf-8").strip()
+            if not tgi_log:
+                continue
+            elif "Connected" in tgi_log:
+                LOGGER.info("\t+ TGI server is ready")
+                break
+            else:
+                LOGGER.info(f"\t {tgi_log}")
+
+        LOGGER.info("\t+ Creating InferenceClient")
+        self.client = InferenceClient(model=f"http://{self.config.address}:{self.config.port}")
+
+    def load_model_from_config(self) -> None:
+        LOGGER.info("\t+ Initializing empty weights model on device: meta")
+        with init_empty_weights():
+            self.pretrained_model = self.automodel_class.from_config(
+                config=self.pretrained_config,
+                torch_dtype=getattr(torch, self.config.torch_dtype),
+                trust_remote_code=self.hub_kwargs.get("trust_remote_code", False),
+            )
+        # could add model dispatching to accelerate saving and support bigger models
+        LOGGER.info(f"\t+ Materializing model on device: {self.device}")
+        self.pretrained_model.to_empty(device=self.device)
+        LOGGER.info("\t+ Randomizing model weights")
+        randomize_weights(self.pretrained_model)
+        LOGGER.info("\t+ Tying weights")
+        self.pretrained_model.tie_weights()
+
+    @property
+    def model_snapshot_path(self) -> str:
+        model_cache_folder = f"models/{self.model}".replace("/", "--")
+        model_cache_path = f"{self.config.volume}/{model_cache_folder}"
+        snapshot_ref = open(f"{model_cache_path}/refs/{self.hub_kwargs.get('revision', 'main')}", "r").read().strip()
+        return f"{model_cache_path}/snapshots/{snapshot_ref}"
+
+    def save_model_snapshot(self) -> None:
+        LOGGER.info("\t+ Saving pretrained model snapshot")
+        self.pretrained_model.save_pretrained(self.model_snapshot_path, safe_serialization=True)
+
+    def load_model_from_pretrained(self) -> None:
+        LOGGER.info("\t+ Downloading pretrained model")
+        with init_empty_weights():
+            self.pretrained_model = self.automodel_class.from_pretrained(self.model, **self.hub_kwargs)
+
+    def modify_generation_config(self) -> None:
+        # this should, theoretically, make the generated output's sequence length fully controlled by max_new_tokens
+        # instead of stopping at the first eos_token_id/pad_token_id
+        generation_config = GenerationConfig.from_pretrained(self.model, **self.hub_kwargs)
+        generation_config.eos_token_id = -100
+        generation_config.pad_token_id = -101
+        generation_config.save_pretrained(self.model_snapshot_path)
+
+    def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        return {"prompt": self.pretrained_processor.batch_decode(input["input_ids"].tolist())}
+
+    def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> List["TextGenerationResponse"]:
+        output = []
+        with ThreadPoolExecutor(max_workers=len(input["prompt"])) as executor:
+            futures = [
+                executor.submit(
+                    self.client.text_generation,
+                    decoder_input_details=True,
+                    prompt=input["prompt"][i],
+                    max_new_tokens=1,
+                    details=True,
+                )
+                for i in range(len(input["prompt"]))
+            ]
+
+        for future in futures:
+            output.append(future.result())
+
+        return output
+
+    def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> List["TextGenerationResponse"]:
+        output = []
+        with ThreadPoolExecutor(max_workers=len(input["prompt"])) as executor:
+            futures = [
+                executor.submit(
+                    self.client.text_generation,
+                    max_new_tokens=kwargs["max_new_tokens"],
+                    do_sample=kwargs["do_sample"],
+                    prompt=input["prompt"][i],
+                    details=True,
+                )
+                for i in range(len(input["prompt"]))
+            ]
+
+        for i in range(len(input["prompt"])):
+            output.append(futures[i].result())
+            if len(output[-1].details["tokens"]) < kwargs["max_new_tokens"]:
+                LOGGER.warning(
+                    f"\t+ Generated {len(output[-1].details['tokens'])} tokens instead of {kwargs['max_new_tokens']}"
+                    " tokens. Benchmark results might be inaccurate."
+                )
+
+        return output
+
+    def clean(self) -> None:
+        super().clean()
+
+        if hasattr(self, "tgi_container"):
+            LOGGER.info("\t+ Stopping TGI container")
+            self.tgi_container.stop()
+            LOGGER.info("\t+ Waiting for TGI container to stop")
+            self.tgi_container.wait()
+
+        if hasattr(self, "docker_client"):
+            LOGGER.info("\t+ Closing docker client")
+            self.docker_client.close()
diff --git a/optimum_benchmark/backends/text_generation_inference/config.py b/optimum_benchmark/backends/text_generation_inference/config.py
new file mode 100644
index 00000000..fb9ac4c6
--- /dev/null
+++ b/optimum_benchmark/backends/text_generation_inference/config.py
@@ -0,0 +1,34 @@
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+from ..config import BackendConfig
+
+
+@dataclass
+class TGIConfig(BackendConfig):
+    name: str = "tgi"
+    version: str = "1.0.3"
+    _target_: str = "optimum_benchmark.backends.text_generation_inference.backend.TGIBackend"
+
+    # server options
+    image: str = "ghcr.io/huggingface/text-generation-inference"
+    volume: str = f"{os.path.expanduser('~')}/.cache/huggingface/hub"
+    shm_size: str = "1g"
+    address: str = "127.0.0.1"
+    port: int = 1111
+
+    # torch options
+    no_weights: bool = False # True, False
+    torch_dtype: Optional[str] = None # None, float32, float16, bfloat16
+    # optimization options
+    disable_custom_kernels: bool = False # True, False
+    # quantization options
+    quantization: Optional[str] = None # None, bitsandbytes-nf4, bitsandbytes-fp4
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        if self.torch_dtype is not None:
+            if self.torch_dtype not in ["float32", "float16", "bfloat16"]:
+                raise ValueError(f"Invalid value for torch_dtype: {self.torch_dtype}")
diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py
index 8b57fc99..00d3ce0f 100644
--- a/optimum_benchmark/benchmarks/inference/benchmark.py
+++ b/optimum_benchmark/benchmarks/inference/benchmark.py
@@ -57,24 +57,21 @@ def run(self, backend: "Backend") -> None:
     def run_forward_tracking(self, backend: "Backend") -> None:
         forward_input = self.input_generator.generate(mode="forward")
 
-        # TODO: can be handled by the backend later
-        for key, value in forward_input.items():
-            if key == "prompt":
-                continue
-            forward_input[key] = value.to(backend.device)
+        LOGGER.info("\t+ Preparing input for the forward pass")
+        forward_input = backend.prepare_input(forward_input)
 
         # for backends that require compilation with static shapes
         backend.prepare_for_inference(input_shapes=self.config.input_shapes)
 
         LOGGER.info("\t+ Warming up the forward pass")
         for _ in range(self.config.warmup_runs):
-            _ = backend.forward(forward_input, **self.config.forward_kwargs)
+            _ = backend.forward(forward_input, self.config.forward_kwargs)
 
         LOGGER.info("\t+ Tracking forward pass latency and throughput")
         latency_tracker = LatencyTracker(device=backend.device, backend=backend.NAME)
         while sum(self.forward_latencies) < self.config.duration:
             with latency_tracker.track():
-                _ = backend.forward(forward_input, **self.config.forward_kwargs)
+                _ = backend.forward(forward_input, self.config.forward_kwargs)
             self.forward_latencies = latency_tracker.get_latencies()
         LOGGER.info(f"\t+ Forward pass latency: {self.forward_latency:.2e} (s)")
         LOGGER.info(f"\t+ Forward pass throughput: {self.forward_throughput:.2f} (samples/s)")
@@ -83,7 +80,7 @@ def run_forward_tracking(self, backend: "Backend") -> None:
         LOGGER.info("\t+ Tracking forward pass peak memory")
         memory_tracker = MemoryTracker(device=backend.device)
         with memory_tracker.track(interval=self.forward_latency / 10):
-            _ = backend.forward(forward_input)
+            _ = backend.forward(forward_input, self.config.forward_kwargs)
         self.forward_peak_memory = memory_tracker.get_peak_memory()
         LOGGER.info(f"\t+ Forward pass peak memory: {self.forward_peak_memory} (MB)")
 
@@ -93,7 +90,7 @@ def run_forward_tracking(self, backend: "Backend") -> None:
         energy_tracker = EnergyTracker()
         with energy_tracker.track(interval=1, file_prefix="forward"):
             while energy_tracker.get_elapsed_time() < self.config.duration:
-                _ = backend.forward(forward_input, **self.config.forward_kwargs)
+                _ = backend.forward(forward_input, self.config.forward_kwargs)
                 num_forward_passes += 1
 
         num_forward_samples = num_forward_passes * self.config.input_shapes["batch_size"]
@@ -110,20 +107,17 @@ def run_forward_tracking(self, backend: "Backend") -> None:
     def run_generate_tracking(self, backend: "Backend") -> None:
         generate_input = self.input_generator.generate(mode="generate")
 
-        # TODO: can be handled by the backend later
-        for key, value in generate_input.items():
-            if key == "prompt":
-                continue
-            generate_input[key] = value.to(backend.device)
+        LOGGER.info("\t+ Preparing input for the generation pass")
+        generate_input = backend.prepare_input(generate_input)
 
         LOGGER.info("\t+ Warming up the generation pass")
-        _ = backend.generate(input=generate_input, **self.config.generate_kwargs)
+        _ = backend.generate(generate_input, self.config.generate_kwargs)
 
         LOGGER.info("\t+ Tracking generation latency and throughput")
         latency_tracker = LatencyTracker(device=backend.device, backend=backend.NAME)
         while sum(self.generate_latencies) < self.config.duration:
             with latency_tracker.track():
-                _ = backend.generate(generate_input, **self.config.generate_kwargs)
+                _ = backend.generate(generate_input, self.config.generate_kwargs)
             self.generate_latencies = latency_tracker.get_latencies()
         LOGGER.info(f"\t+ Generation pass latency: {self.generate_latency:.2e} (s)")
         LOGGER.info(f"\t+ Generation pass throughput: {self.generate_throughput:.2f} (tokens/s)")
@@ -132,7 +126,7 @@ def run_generate_tracking(self, backend: "Backend") -> None:
         LOGGER.info("\t+ Tracking generation pass peak memory")
         memory_tracker = MemoryTracker(device=backend.device)
         with memory_tracker.track(interval=self.generate_latency / 10):
-            _ = backend.generate(generate_input, **self.config.generate_kwargs)
+            _ = backend.generate(generate_input, self.config.generate_kwargs)
         self.generate_peak_memory = memory_tracker.get_peak_memory()
         LOGGER.info(f"\t+ Generation pass peak memory: {self.generate_peak_memory} (MB)")
 
@@ -142,7 +136,7 @@ def run_generate_tracking(self, backend: "Backend") -> None:
         energy_tracker = EnergyTracker()
         with energy_tracker.track(interval=1, file_prefix="generate"):
             while energy_tracker.get_elapsed_time() < self.config.duration:
-                _ = backend.generate(generate_input, **self.config.generate_kwargs)
+                _ = backend.generate(generate_input, self.config.generate_kwargs)
                 num_generate_passes += 1
 
         num_generated_tokens = (
diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py
index 7d1fbeb0..6941c66a 100644
--- a/optimum_benchmark/experiment.py
+++ b/optimum_benchmark/experiment.py
@@ -13,6 +13,7 @@
 from .backends.onnxruntime.config import ORTConfig
 from .backends.openvino.config import OVConfig
 from .backends.pytorch.config import PyTorchConfig
+from .backends.text_generation_inference.config import TGIConfig
 from .benchmarks.inference.config import InferenceConfig
 from .benchmarks.training.config import TrainingConfig
 from .env_utils import get_cpu, get_cpu_ram_mb, get_gpus
@@ -102,10 +103,11 @@ def __post_init__(self) -> None:
 # Register configurations
 cs = ConfigStore.instance()
 cs.store(name="experiment", node=ExperimentConfig)
+cs.store(group="backend", name="openvino", node=OVConfig)
 cs.store(group="backend", name="pytorch", node=PyTorchConfig)
 cs.store(group="backend", name="onnxruntime", node=ORTConfig)
-cs.store(group="backend", name="openvino", node=OVConfig)
 cs.store(group="backend", name="neural_compressor", node=INCConfig)
+cs.store(group="backend", name="text-generation-inference", node=TGIConfig)
 cs.store(group="benchmark", name="inference", node=InferenceConfig)
 cs.store(group="benchmark", name="training", node=TrainingConfig)
@@ -131,6 +133,7 @@ def run_experiment(experiment: DictConfig) -> None:
         backend.configure(experiment.backend)
     except Exception as e:
         LOGGER.error("Error during backend configuration: %s", e)
+        backend.clean()
         raise e
 
     # Allocate requested benchmark
@@ -140,6 +143,7 @@ def run_experiment(experiment: DictConfig) -> None:
         benchmark.configure(experiment.benchmark)
     except Exception as e:
         LOGGER.error("Error during benchmark configuration: %s", e)
+        backend.clean()
         raise e
 
     try:
diff --git a/optimum_benchmark/generators/input_generator.py b/optimum_benchmark/generators/input_generator.py
index 56f97ea9..a37e4fe5 100644
--- a/optimum_benchmark/generators/input_generator.py
+++ b/optimum_benchmark/generators/input_generator.py
@@ -4,14 +4,8 @@
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
 
-from optimum_benchmark.generators.model_type_generator import (
-    SUPPURTED_MODEL_TYPES,
-    ModelTypeGenerator,
-)
-from optimum_benchmark.generators.task_generator import (
-    TASKS_TO_GENERATORS,
-    TaskGenerator,
-)
+from ..generators.model_type_generator import SUPPURTED_MODEL_TYPES, ModelTypeGenerator
+from ..generators.task_generator import TASKS_TO_GENERATORS, TaskGenerator
 
 LOGGER = getLogger("input_generator")
diff --git a/optimum_benchmark/generators/model_type_generator.py b/optimum_benchmark/generators/model_type_generator.py
index 4ae1f93c..f8facc3a 100644
--- a/optimum_benchmark/generators/model_type_generator.py
+++ b/optimum_benchmark/generators/model_type_generator.py
@@ -5,9 +5,14 @@
 from optimum.exporters.tasks import TasksManager
 from transformers import PretrainedConfig
 
+from ..import_utils import is_onnx_available
+
 LOGGER = getLogger("model_type_generator")
 
-SUPPURTED_MODEL_TYPES: List[str] = list(TasksManager._SUPPORTED_MODEL_TYPE.keys())
+EXPORTER = "onnx" # used for its configs as input generators
+SUPPURTED_MODEL_TYPES: List[str] = (
+    list(TasksManager._SUPPORTED_MODEL_TYPE.keys()) if is_onnx_available() else []
+) # should be empty if onnx is not available
 
 
 class ModelTypeGenerator:
@@ -26,7 +31,7 @@ def __init__(
 
         self.onnx_config = TasksManager.get_exporter_config_constructor(
             task=task,
-            exporter="onnx",
+            exporter=EXPORTER,
             model_type=model_type,
         )(pretrained_config)
diff --git a/optimum_benchmark/import_utils.py b/optimum_benchmark/import_utils.py
index 0c988ba8..8e70bd4b 100644
--- a/optimum_benchmark/import_utils.py
+++ b/optimum_benchmark/import_utils.py
@@ -6,6 +6,7 @@
 _diffusers_available = importlib.util.find_spec("diffusers") is not None
 _optimum_available = importlib.util.find_spec("optimum") is not None
 _torch_available = importlib.util.find_spec("torch") is not None
+_onnx_available = importlib.util.find_spec("onnx") is not None
 _py3nvml_available = importlib.util.find_spec("py3nvml") is not None
 _torch_distributed_available = importlib.util.find_spec("torch.distributed")
 _onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None
@@ -13,6 +14,18 @@
 _neural_compressor_available = importlib.util.find_spec("neural_compressor") is not None
 
 
+def is_onnx_available():
+    return _onnx_available
+
+
+def is_optimum_available():
+    return _optimum_available
+
+
+def is_onnxruntime_available():
+    return _onnxruntime_available
+
+
 def is_py3nvml_available():
     return _py3nvml_available
diff --git a/tgi_requirements.txt b/tgi_requirements.txt
new file mode 100644
index 00000000..989ba1a4
--- /dev/null
+++ b/tgi_requirements.txt
@@ -0,0 +1,26 @@
+# HF ecosystem
+# transformers
+git+https://github.com/huggingface/transformers.git
+# optimum
+git+https://github.com/huggingface/optimum.git
+# accelerate
+git+https://github.com/huggingface/accelerate.git
+
+# hydra
+omegaconf==2.3.0
+hydra-core==1.3.2
+hydra_colorlog==1.2.0
+hydra-joblib-launcher==1.2.0
+
+# system
+docker==6.1.3
+psutil==5.9.0
+py3nvml==0.2.7
+codecarbon==2.3.1
+
+# reporting
+flatten_dict
+matplotlib
+seaborn
+pandas
+rich
\ No newline at end of file
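
Note on the concurrency model used by the new backend: TGIBackend.forward and TGIBackend.generate turn a batch of prompts into one concurrent request per prompt against the TGI server and collect the responses in submission order. The snippet below is a minimal, self-contained sketch of that fan-out pattern, not part of the patch itself; fake_text_generation is a hypothetical stand-in for the blocking InferenceClient.text_generation call.

# Illustrative sketch only; `fake_text_generation` is a hypothetical stand-in
# for huggingface_hub.InferenceClient.text_generation (a blocking HTTP call).
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def fake_text_generation(prompt: str, max_new_tokens: int) -> str:
    time.sleep(0.1)  # simulate the latency of one request to the TGI server
    return prompt + " <token>" * max_new_tokens


def generate_batch(prompts: List[str], max_new_tokens: int) -> List[str]:
    # one worker per prompt, so every request in the batch is in flight at once
    with ThreadPoolExecutor(max_workers=len(prompts)) as executor:
        futures = [executor.submit(fake_text_generation, prompt, max_new_tokens) for prompt in prompts]
    # exiting the `with` block waits for all futures; results keep the batch order
    return [future.result() for future in futures]


if __name__ == "__main__":
    print(generate_batch(["Hello", "World"], max_new_tokens=3))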