🐛 [Bug] Error when serving Torch-TensorRT JIT model to Nvidia-Triton #3248

Open
zmy1116 opened this issue Oct 18, 2024 · 6 comments · Fixed by #3292
Labels
bug Something isn't working


@zmy1116

zmy1116 commented Oct 18, 2024

Bug Description

I'm trying to serve a Torch-TensorRT-optimized model on NVIDIA Triton Inference Server, following the provided tutorial:
https://pytorch.org/TensorRT/tutorials/serving_torch_tensorrt_with_triton.html

First, the provided script for generating the optimized model does not work; I tweaked it a bit to get it working. Then, when I try to run inference on the Triton server, I get the error:
ERROR: [Torch-TensorRT] - IExecutionContext::enqueueV3: Error Code 1: Cuda Runtime (invalid resource handle)

To Reproduce

The PyTorch page provides the following script to save the optimized JIT model:

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
    inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions= { torch.half} # Run with FP16
)

# Save the model
torch.jit.save(trt_model, "model.pt")

When I run this script, I get the error AttributeError: 'GraphModule' object has no attribute 'save'

To work around this, I tried the following two approaches:

  1. Save only the compiled TensorRT submodule with torch.jit.save:
    torch.jit.save(trt_model._run_on_acc_0, "/home/ubuntu/model.pt")

  2. Trace the model to TorchScript first, then compile the traced module:

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
model_jit = torch.jit.trace(model, [torch.rand(1,3,224,224).cuda()])
trt_model = torch_tensorrt.compile(model_jit,
    inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions= { torch.half} # Run with FP16
)

I confirmed that both approaches produce a valid TorchScript model.

I then placed the model in a folder with the same structure the tutorial describes and launched the Triton server, which started successfully:

I1018 03:38:23.657822 1 server.cc:674] 
+----------+---------+--------+
| Model    | Version | Status |
+----------+---------+--------+
| resnet50 | 1       | READY  |
+----------+---------+--------+

I1018 03:38:23.886797 1 metrics.cc:877] "Collecting metrics for GPU 0: NVIDIA L4"
I1018 03:38:23.886839 1 metrics.cc:877] "Collecting metrics for GPU 1: NVIDIA L4"
I1018 03:38:23.886852 1 metrics.cc:877] "Collecting metrics for GPU 2: NVIDIA L4"
I1018 03:38:23.886864 1 metrics.cc:877] "Collecting metrics for GPU 3: NVIDIA L4"
I1018 03:38:23.886873 1 metrics.cc:877] "Collecting metrics for GPU 4: NVIDIA L4"
I1018 03:38:23.886882 1 metrics.cc:877] "Collecting metrics for GPU 5: NVIDIA L4"
I1018 03:38:23.886893 1 metrics.cc:877] "Collecting metrics for GPU 6: NVIDIA L4"
I1018 03:38:23.886901 1 metrics.cc:877] "Collecting metrics for GPU 7: NVIDIA L4"
I1018 03:38:23.916949 1 metrics.cc:770] "Collecting CPU metrics"
I1018 03:38:23.917116 1 tritonserver.cc:2598] 

+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Option                           | Value                                                                                                                                                                                                           |
+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| server_id                        | triton                                                                                                                                                                                                          |
| server_version                   | 2.50.0                                                                                                                                                                                                          |
| server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_policy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data parameters statistics trace logging |
| model_repository_path[0]         | /home/ubuntu/model_repository_4                                                                                                                                                                                 |
| model_control_mode               | MODE_NONE                                                                                                                                                                                                       |
| strict_model_config              | 0                                                                                                                                                                                                               |
| model_config_name                |                                                                                                                                                                                                                 |
| rate_limit                       | OFF                                                                                                                                                                                                             |
| pinned_memory_pool_byte_size     | 268435456                                                                                                                                                                                                       |
| cuda_memory_pool_byte_size{0}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{1}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{2}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{3}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{4}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{5}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{6}    | 67108864                                                                                                                                                                                                        |
| cuda_memory_pool_byte_size{7}    | 67108864                                                                                                                                                                                                        |
| min_supported_compute_capability | 6.0                                                                                                                                                                                                             |
| strict_readiness                 | 1                                                                                                                                                                                                               |
| exit_timeout                     | 30                                                                                                                                                                                                              |
| cache_enabled                    | 0                                                                                                                                                                                                               |
+----------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

However, when I run inference, I get the error:
ERROR: [Torch-TensorRT] - IExecutionContext::enqueueV3: Error Code 1: Cuda Runtime (invalid resource handle)

Expected behavior

I expect inference to succeed. I want to serve Torch-TensorRT-optimized models on NVIDIA Triton. Our team has observed that, on models like SAM2, Torch-TensorRT is significantly faster than a model converted via Torch -> ONNX -> TensorRT. Our entire inference stack runs on NVIDIA Triton, and we would like to take advantage of this new tool.

Environment

We use the NVIDIA NGC Docker images directly:
PyTorch for model optimization: nvcr.io/nvidia/pytorch:24.09-py3
Triton for hosting: nvcr.io/nvidia/tritonserver:24.09-py3

Additional context

Our current stack actually runs on tritonserver:24.03, and we confirmed that it also fails with nvcr.io/nvidia/tritonserver:24.03-py3 and nvcr.io/nvidia/pytorch:24.03-py3.

Please let us know if you need additional information.

@narendasan
Collaborator

It seems like you are mixing Dynamo and TorchScript. There are two options:

1. Use Dynamo to trace and deploy in TorchScript (this is what we recommend):

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
    # ir="dynamo" is used implicitly
    inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions= { torch.half} # Run with FP16
)

# Save the model
torch_tensorrt.save(trt_model, "model.ts", output_format="torchscript", inputs=torch.randn((1,3,224,224)))

(https://pytorch.org/TensorRT/user_guide/saving_models.html)

2. Alternatively, use the TorchScript frontend:

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
    ir="torchscript",
    inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions= { torch.half} # Run with FP16
)

# Save the model
torch.jit.save(trt_model, "model.pt")
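The two options differ in which frontend torch_tensorrt.compile uses, and the frontend decides what kind of object is returned. The function below is a hypothetical, simplified model of that choice, written only for illustration. It is not Torch-TensorRT's actual code, and the type names are plain strings standing in for torch.fx.GraphModule and torch.jit.ScriptModule, so the sketch does not need torch.

```python
from enum import Enum


class Frontend(Enum):
    DYNAMO = "dynamo"
    TORCHSCRIPT = "torchscript"


# Kind of object each frontend returns (names only; this sketch does not import torch).
RETURN_TYPE = {
    Frontend.DYNAMO: "torch.fx.GraphModule",
    Frontend.TORCHSCRIPT: "torch.jit.ScriptModule",
}


def pick_frontend(is_script_module: bool, ir: str = "default") -> Frontend:
    """Hypothetical, simplified model of frontend selection in torch_tensorrt.compile.

    Only the cases discussed in this thread are modeled; other ir values and
    error handling are deliberately left out.
    """
    if ir in ("torchscript", "ts"):
        return Frontend.TORCHSCRIPT
    if ir == "dynamo":
        return Frontend.DYNAMO
    if ir == "default":
        # A TorchScript module stays on TorchScript; a plain nn.Module goes to Dynamo.
        return Frontend.TORCHSCRIPT if is_script_module else Frontend.DYNAMO
    raise ValueError(f"ir value not covered by this sketch: {ir!r}")


def works_with_torch_jit_save(is_script_module: bool, ir: str = "default") -> bool:
    # torch.jit.save expects a ScriptModule, so only the TorchScript path yields one.
    return RETURN_TYPE[pick_frontend(is_script_module, ir)] == "torch.jit.ScriptModule"
```

For the tutorial's original call (an eager nn.Module with the default ir), the sketch predicts a GraphModule, which is consistent with the AttributeError reported above. Passing ir="torchscript", or a module already traced with torch.jit.trace, predicts a ScriptModule that torch.jit.save can handle.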

@zmy1116
Author

zmy1116 commented Oct 18, 2024

@narendasan Thanks for the quick reply.

If I run the first script you provided:

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
    # ir="dynamo" is used implicitly
    inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions= { torch.half} # Run with FP16
)

# Save the model
torch_tensorrt.save(trt_model, "model.ts", output_format="torchscript", inputs=torch.randn((1,3,224,224)))

I get this error at the end, when saving the model:

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 16
      9 trt_model = torch_tensorrt.compile(model,
     10     #ir="dynamo" implicitly 
     11     inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
     12     enabled_precisions= { torch.half} # Run with FP32
     13 )
     15 # Save the model
---> 16 torch_tensorrt.save(trt_model, "model.ts", output_format="torchscript", inputs=torch.randn((1,3,224,224)))

File /usr/local/lib/python3.10/dist-packages/torch_tensorrt/_compile.py:481, in save(module, file_path, output_format, inputs, arg_inputs, kwarg_inputs, retrace)
    476 if arg_inputs and inputs:
    477     raise AssertionError(
    478         "'arg_inputs' and 'inputs' should not be used at the same time."
    479     )
--> 481 arg_inputs = inputs or arg_inputs
    483 if kwarg_inputs is None:
    484     kwarg_inputs = {}

RuntimeError: Boolean value of Tensor with more than one value is ambiguous
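The traceback points at arg_inputs = inputs or arg_inputs inside torch_tensorrt.save. Python's or operator calls bool() on inputs, and a tensor with more than one element has no single truth value, which produces this RuntimeError. The pure-Python sketch below reproduces the behavior with a hypothetical FakeTensor class standing in for torch.Tensor, so no torch import is needed. It also shows that wrapping the tensor in a list passes the same check, because a non-empty list is truthy regardless of its contents. Under that reading, passing inputs=[torch.randn((1, 3, 224, 224)).cuda()] (a list, as in the examples in the saving guide linked earlier) should get past this line. That is an assumption and was not verified in this thread.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only models truthiness."""

    def __init__(self, numel: int):
        self.numel = numel

    def __bool__(self) -> bool:
        # Like torch.Tensor, only a single-element tensor has a truth value.
        if self.numel != 1:
            raise RuntimeError(
                "Boolean value of Tensor with more than one value is ambiguous"
            )
        return True  # simplification: a real tensor returns bool(its value)


def resolve_inputs(inputs=None, arg_inputs=None):
    # Mirrors the line from the traceback: `or` calls bool() on `inputs`.
    return inputs or arg_inputs


image = FakeTensor(numel=1 * 3 * 224 * 224)

try:
    resolve_inputs(inputs=image)  # bare tensor -> bool(image) -> RuntimeError
    bare_tensor_error = None
except RuntimeError as exc:
    bare_tensor_error = str(exc)

# A non-empty list is truthy by length, so bool() is never called on the tensor.
wrapped = resolve_inputs(inputs=[image])
```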

With the second script I can get the TorchScript model. However, when I run inference with this model on the Triton server, I still get the same error:
ERROR: [Torch-TensorRT] - IExecutionContext::enqueueV3: Error Code 1: Cuda Runtime (invalid resource handle)

Can anyone on your end confirm whether a Torch-TensorRT-optimized model can (or cannot) run on NVIDIA Triton? (That is, confirm whether this tutorial works: https://pytorch.org/TensorRT/tutorials/serving_torch_tensorrt_with_triton.html)

This tutorial has been on the Torch-TensorRT site for a while, and over the past year I've tried it multiple times across multiple Triton server / Torch-TensorRT versions. It has never worked.

@narendasan
Collaborator

Will poke around; it might just be that the tutorial is outdated.

narendasan added a commit that referenced this issue Nov 12, 2024
Fixes: #3248

Signed-off-by: Naren Dasan
@narendasan
Collaborator

narendasan commented Nov 12, 2024

@zmy1116 I updated the Triton tutorial. There seem to have been some subtle things that could be off, but for the most part nothing has changed. I uploaded scripts that I have verified to work for exporting and querying a ResNet model. Hopefully that is enough to go on.

The TL;DR of the tutorial in #3292: check out that branch, go to //examples/triton, and run the following in one terminal:

# Could be any recent published tag (I tested with 24.08); just use the same tag for all containers so that the TensorRT versions match

# Export model into model repo 
docker run --gpus all -it --rm -v ${PWD}:/triton_example nvcr.io/nvidia/pytorch:24.08-py3 python /triton_example/export.py

# Start server 
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/triton_example/model_repository

and in another terminal:

# Get a sample image
wget -O img1.jpg "https://www.hakaimagazine.com/wp-content/uploads/header-gulf-birds.jpg"

# Query server
docker run -it --net=host -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:24.08-py3-sdk bash -c "pip install torchvision && python /triton_example/client.py"

You should get an output like:

[b'12.460938:90' b'11.523438:92' b'9.656250:14' b'8.414062:136'
 b'8.210938:11']
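Each entry in that output is a byte string of the form <score>:<class index>, which is how Triton's classification extension reports top-k results (a label is appended as a third field when the model config supplies a labels file). The sketch below turns those entries into (class index, score, label) tuples. The helper name parse_classification is hypothetical, and the sample values are copied from the output above.

```python
from typing import List, Optional, Tuple


def parse_classification(entries: List[bytes]) -> List[Tuple[int, float, Optional[str]]]:
    """Parse Triton classification strings of the form b"<score>:<index>[:<label>]"."""
    results = []
    for entry in entries:
        parts = entry.decode("utf-8").split(":", 2)
        score, index = float(parts[0]), int(parts[1])
        label = parts[2] if len(parts) == 3 else None  # no labels file -> no label
        results.append((index, score, label))
    return results


# Top-5 output from the client run above.
sample = [b"12.460938:90", b"11.523438:92", b"9.656250:14", b"8.414062:136", b"8.210938:11"]
top5 = parse_classification(sample)
```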

You can take a look at config.pbtxt for the Triton config I used. I would recommend using explicit dimension sizes when possible.
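For reference, here is a sketch of a model repository in the layout Triton's PyTorch (LibTorch) backend expects: <repository>/<model name>/config.pbtxt, plus a numbered version directory that holds model.pt. The config text is an illustrative assumption that follows the advice above (backend: "pytorch", max_batch_size: 0, and explicit dims that include the batch dimension). It is not the config.pbtxt from #3292. The tensor names input__0/output__0 follow the backend's <name>__<index> convention, and the FP32 input with 1000 output classes assumes a standard torchvision ResNet-50.

```python
from pathlib import Path

# Hypothetical config: explicit shapes, with the batch dimension included because max_batch_size is 0.
CONFIG_PBTXT = """\
name: "resnet50"
backend: "pytorch"
max_batch_size: 0
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 1, 3, 224, 224 ]
  }
]
output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 1, 1000 ]
  }
]
"""


def make_model_repository(root: Path, model_name: str = "resnet50", version: int = 1) -> Path:
    """Create <root>/<model_name>/config.pbtxt and <root>/<model_name>/<version>/.

    Returns the path where the exported TorchScript file should be saved.
    """
    model_dir = root / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    (model_dir / "config.pbtxt").write_text(CONFIG_PBTXT)
    # By default the PyTorch backend loads <version>/model.pt.
    return version_dir / "model.pt"
```

The returned path is where the export step would write model.pt. The name field in the config must match the model directory name.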

narendasan added a commit that referenced this issue Nov 13, 2024
Fixes: #3248

Signed-off-by: Naren Dasan
@zmy1116
Author

zmy1116 commented Nov 15, 2024

@narendasan

Thanks for the fix. I tested it and confirmed that, at least for ResNet-50, it works on the newest NGC version, 24.10.

I also compared your setup with the code I was using to create the model, start the server, and run inference. The differences I see are:

  • settings related to torch_tensorrt.compile
  • you use backend: "pytorch"; I used platform: "pytorch_libtorch"
  • you set max_batch_size: 0 and put the batch size in the input/output dims; I set max_batch_size: 1 and have no explicit batch size in the input/output dims
  • you use the HTTP client for inference; I use our internal tool based on the gRPC client
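The max_batch_size difference changes the tensor shapes Triton expects. Under Triton's model-configuration rules, when max_batch_size is greater than 0 the batch dimension is implicit and the full shape is [-1] + dims. When it is 0, dims must give the complete shape, including the batch dimension. The sketch below compares the two configurations listed above for a single 1x3x224x224 image. The helper name full_shape is hypothetical.

```python
from typing import List


def full_shape(max_batch_size: int, dims: List[int]) -> List[int]:
    """Full tensor shape Triton uses, given a config's max_batch_size and dims."""
    if max_batch_size < 0:
        raise ValueError("max_batch_size must be >= 0")
    # With batching enabled, Triton prepends a variable-size batch dimension (-1).
    return [-1] + list(dims) if max_batch_size > 0 else list(dims)


# Tutorial-style config: batching disabled, batch dimension written out explicitly.
tutorial_input = full_shape(0, [1, 3, 224, 224])
# Reporter-style config: max_batch_size 1, batch dimension left implicit.
reporter_input = full_shape(1, [3, 224, 224])
```

In the second case, each request may carry up to max_batch_size items (here 1) along the leading dimension. Both configurations therefore accept one 1x3x224x224 image, which fits the observation that this difference is not the cause.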

I checked each of these differences and confirmed that none of them actually causes the problem.

I then tried version 24.09 (the version I was using when I filed the bug report) and found the following:

  • If I launch the Triton server on only 1 GPU, there is no problem.
  • If I launch the Triton server on more GPUs, I sometimes get the CUDA error.
    • The more GPUs I use, the higher the chance of the error: with 2 GPUs, around 50% of calls fail; with 4 GPUs, around 75% of calls fail.
      Since I was using an 8-GPU machine when I created the issue, it may simply be that most calls failed back then.

I tested both 24.09 and 24.03 (the version we currently serve our models on). I can confirm that with 1 GPU the server runs correctly, and that the error rate grows as more GPUs are used.
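The reported failure rates (0% with 1 GPU, about 50% with 2, about 75% with 4) match 1 − 1/N for N GPUs. One possible reading is that Triton's default of one model instance per GPU spreads requests evenly and only the instance on a single GPU works, so N − 1 out of N requests fail. That reading is a hypothesis and is not confirmed in this thread. The sketch below only does the arithmetic, and the 8-GPU figure is an extrapolation, not a measurement.

```python
from fractions import Fraction


def predicted_failure_rate(num_gpus: int) -> Fraction:
    """Hypothesis: only 1 of num_gpus evenly loaded instances serves requests correctly."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    return 1 - Fraction(1, num_gpus)


# Approximate rates reported in this thread.
observed = {1: Fraction(0), 2: Fraction(1, 2), 4: Fraction(3, 4)}
matches = all(predicted_failure_rate(n) == rate for n, rate in observed.items())

# Extrapolation only: an 8-GPU machine would fail about 7 out of 8 calls.
eight_gpu = predicted_failure_rate(8)
```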

Thank you.

@narendasan
Collaborator

narendasan commented Nov 18, 2024

Hmm, well, it's good that at least 1 GPU works. At this point, I think the folks at https://github.com/triton-inference-server/server would be better able to debug what is happening. On our side, we mostly focus on model export and the runtime extension, and they handle all of the orchestration.
