
🐛 [Bug] _RichMonitor performance regression #3215

Open
HolyWu opened this issue Oct 6, 2024 · 0 comments
Labels
bug Something isn't working


HolyWu commented Oct 6, 2024

Bug Description

When debug=True and the rich module is available, _RichMonitor makes engine building take significantly longer, and the runtime performance of the resulting engine also decreases. Using _ASCIIMonitor (i.e. when the rich module is unavailable) is much faster than _RichMonitor, but still slightly slower than not using a progress monitor at all.
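As an illustration of the kind of overhead involved (this is a hypothetical sketch, not Torch-TensorRT or rich internals): a progress monitor that re-renders output on every build step adds per-step work that a monitor-free loop does not pay. The `build_loop` and `rich_like_monitor` names below are made up for the example.

```python
import io
import time

STEPS = 10_000


def build_loop(monitor=None):
    """Stand-in for an engine build loop; optionally calls a per-step monitor."""
    start = time.perf_counter()
    for step in range(STEPS):
        _ = step * step  # stand-in for real per-layer build work
        if monitor is not None:
            monitor(step)
    return time.perf_counter() - start


buf = io.StringIO()


def rich_like_monitor(step):
    # Re-rendering a progress bar on every step is the expensive part.
    buf.write(f"\rBuilding layer {step}/{STEPS} " + "#" * (step % 40))


no_monitor = build_loop()
with_monitor = build_loop(rich_like_monitor)
print(f"no monitor:   {no_monitor:.4f}s")
print(f"with monitor: {with_monitor:.4f}s")
```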

To Reproduce

from __future__ import annotations

import os
import tempfile

import numpy as np
import torch
import torch_tensorrt
import torchvision.models as models

# uncomment to disable progress monitor
# os.environ["CI_BUILD"] = "1"

times = 100


@torch.inference_mode()
def benchmark(model: torch.nn.Module, inputs: list[torch.Tensor]) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
        torch.cuda._sleep(1_000_000)
        start_events[i].record()
        model(inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()
    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


if __name__ == "__main__":
    torch.manual_seed(12345)
    device = torch.device("cuda", 0)
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2).eval().to(device).half()

    inputs = [torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.half)]

    timing_cache_path = os.path.join(tempfile.gettempdir(), "nonexistent_timing.cache")
    if os.path.isfile(timing_cache_path):
        os.remove(timing_cache_path)

    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        device=device,
        enabled_precisions={torch.half},
        debug=True,
        min_block_size=1,
        timing_cache_path=timing_cache_path,
    )

    inputs = [torch.rand((1, 3, 224, 224), dtype=torch.half, device=device) for _ in range(times)]
    torch_timings = benchmark(model, inputs)
    trt_timings = benchmark(trt_model, inputs)

    print("")
    print("Torch timings:")
    print(
        f"Min={torch_timings.min()} milliseconds, "
        f"Mean={torch_timings.mean()} milliseconds, "
        f"Max={torch_timings.max()} milliseconds"
    )

    print("")
    print("TRT timings:")
    print(
        f"Min={trt_timings.min()} milliseconds, "
        f"Mean={trt_timings.mean()} milliseconds, "
        f"Max={trt_timings.max()} milliseconds"
    )

_RichMonitor

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:10:42.736987

Torch timings:
Min=2.6961920261383057 milliseconds, Mean=5.373901104927063 milliseconds, Max=15.220735549926758 milliseconds

TRT timings:
Min=3.9034879207611084 milliseconds, Mean=4.16745824098587 milliseconds, Max=4.49945592880249 milliseconds

_ASCIIMonitor

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:02:11.301710

Torch timings:
Min=2.4842240810394287 milliseconds, Mean=5.146092460155487 milliseconds, Max=10.242079734802246 milliseconds

TRT timings:
Min=2.702336072921753 milliseconds, Mean=3.2247164821624756 milliseconds, Max=6.435840129852295 milliseconds

Progress monitor disabled

INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:01:34.190098

Torch timings:
Min=2.6552319526672363 milliseconds, Mean=4.810590088367462 milliseconds, Max=9.890815734863281 milliseconds

TRT timings:
Min=2.5118720531463623 milliseconds, Mean=2.995586881637573 milliseconds, Max=4.638751983642578 milliseconds
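For reference, taking the monitor-disabled run as the baseline, the build times and mean TRT latencies reported above work out to roughly these slowdown factors (values transcribed from the logs, rounded):

```python
# Slowdown factors computed from the numbers reported above.
build_seconds = {
    "_RichMonitor": 10 * 60 + 42.74,   # 0:10:42.74
    "_ASCIIMonitor": 2 * 60 + 11.30,   # 0:02:11.30
    "disabled": 60 + 34.19,            # 0:01:34.19
}
mean_trt_ms = {
    "_RichMonitor": 4.167,
    "_ASCIIMonitor": 3.225,
    "disabled": 2.996,
}

for name in ("_RichMonitor", "_ASCIIMonitor"):
    build_x = build_seconds[name] / build_seconds["disabled"]
    run_x = mean_trt_ms[name] / mean_trt_ms["disabled"]
    print(f"{name}: {build_x:.1f}x build time, {run_x:.2f}x mean TRT latency")
```

So _RichMonitor is nearly 7x slower to build and about 1.4x slower at runtime, while _ASCIIMonitor costs roughly 40% extra build time.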

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0.dev20241004+cu124
  • PyTorch Version (e.g. 1.0): 2.6.0.dev20241003+cu124
  • CPU Architecture: x64
  • OS (e.g., Linux): Windows 11 and Ubuntu 24.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.12
  • CUDA version: 12.4
  • GPU models and configuration: RTX 4060 Ti
  • Any other relevant information:
@HolyWu HolyWu added the bug Something isn't working label Oct 6, 2024