Skip to content

Commit

Permalink
docs: Updated tutorial for triton + torch-tensorrt
Browse files Browse the repository at this point in the history
Fixes: #3248

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Nov 12, 2024
1 parent a405443 commit 06484f7
Show file tree
Hide file tree
Showing 6 changed files with 422 additions and 61 deletions.
143 changes: 82 additions & 61 deletions docsrc/tutorials/serving_torch_tensorrt_with_triton.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,55 @@ Step 1: Optimize your model with Torch-TensorRT
Most Torch-TensorRT users will be familiar with this step. For the purpose of
this demonstration, we will be using a ResNet50 model from Torchhub.

Let’s first pull the `NGC PyTorch Docker container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`__. You may need to create
We will be working in the ``//examples/triton`` directory which contains the scripts used in this tutorial.

First pull the `NGC PyTorch Docker container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`__. You may need to create
an account and get the API key from `here <https://ngc.nvidia.com/setup/>`__.
Sign up and login with your key (follow the instructions
`here <https://ngc.nvidia.com/setup/api-key>`__ after signing up).

::

# <xx.xx> is the yy:mm for the publishing tag for NVIDIA's Pytorch
# container; eg. 22.04
# YY.MM is the yy:mm for the publishing tag for NVIDIA's Pytorch
# container; eg. 24.08
# NOTE: Use the publishing tag for both the PyTorch container and the Triton Containers

docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:<xx.xx>-py3
docker run -it --gpus all -v ${PWD}:/scratch_space nvcr.io/nvidia/pytorch:YY.MM-py3
cd /scratch_space

Once inside the container, we can proceed to download a ResNet model from
Torchhub and optimize it with Torch-TensorRT.
With the container we can export the model in to the correct directory in our Triton model repository. This export script uses the **Dynamo** frontend for Torch-TensorRT to compile the PyTorch model to TensorRT. Then we save the model using **TorchScript** as a serialization format which is supported by Triton.

::

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
import torch
import torch_tensorrt as torchtrt
import torchvision

import torch
import torch_tensorrt
torch.hub._validate_not_a_forked_repo=lambda a,b,c: True

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
enabled_precisions= {torch_tensorrt.dtype.f16}
)

ts_trt_model = torch.jit.trace(trt_model, torch.rand(1, 3, 224, 224).to("cuda"))

# Save the model
torch.jit.save(ts_trt_model, "/triton_example/model_repository/resnet50/1/model.pt")

# load model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True).eval().to("cuda")
You can run the script with the following command (from ``//examples/triton``)

# Compile with Torch TensorRT;
trt_model = torch_tensorrt.compile(model,
inputs= [torch_tensorrt.Input((1, 3, 224, 224))],
enabled_precisions= { torch.half} # Run with FP32
)
::

# Save the model
torch.jit.save(trt_model, "model.pt")
docker run --gpus all -it --rm -v ${PWD}:/triton_example nvcr.io/nvidia/pytorch:YY.MM-py3 python /triton_example/export.py

After copying the model, exit the container. The next step in the process
is to set up a Triton Inference Server.
This will save the serialized TorchScript version of the ResNet model in the right directory in the model repository.

Step 2: Set Up Triton Inference Server
--------------------------------------
Expand Down Expand Up @@ -90,25 +103,23 @@ For the model we prepared in step 1, the following configuration can be used:

::

name: "resnet50"
platform: "pytorch_libtorch"
max_batch_size : 0
input [
{
name: "input__0"
data_type: TYPE_FP32
dims: [ 3, 224, 224 ]
reshape { shape: [ 1, 3, 224, 224 ] }
}
]
output [
{
name: "output__0"
data_type: TYPE_FP32
dims: [ 1, 1000 ,1, 1]
reshape { shape: [ 1, 1000 ] }
}
]
name: "resnet50"
backend: "pytorch"
max_batch_size : 0
input [
{
name: "x"
data_type: TYPE_FP32
dims: [ 1, 3, 224, 224 ]
}
]
output [
{
name: "output0"
data_type: TYPE_FP32
dims: [1, 1000]
}
]

The ``config.pbtxt`` file is used to describe the exact model configuration
with details like the names and shapes of the input and output layer(s),
Expand All @@ -124,9 +135,9 @@ with the docker command below. Refer `this page <https://catalog.ngc.nvidia.com/

# Make sure that the TensorRT version in the Triton container
# and TensorRT version in the environment used to optimize the model
# are the same.
# are the same. Roughly, like publishing tags should have the same TensorRT version

docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /full/path/to/the_model_repository/model_repository:/models nvcr.io/nvidia/tritonserver:<xx.yy>-py3 tritonserver --model-repository=/models
docker run --gpus all --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:24.08-py3 tritonserver --model-repository=/triton_example/model_repository

This should spin up a Triton Inference server. Next step, building a simple
http client to query the server.
Expand Down Expand Up @@ -159,22 +170,24 @@ resize and normalize the query image.

::

import numpy as np
from torchvision import transforms
from PIL import Image
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype

# preprocessing function
def rn50_preprocess(img_path="img1.jpg"):
img = Image.open(img_path)
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
return preprocess(img).numpy()
import numpy as np
from torchvision import transforms
from PIL import Image
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype

# preprocessing function
def rn50_preprocess(img_path="/triton_example/img1.jpg"):
img = Image.open(img_path)
preprocess = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
return preprocess(img).unsqueeze(0).numpy()

transformed_img = rn50_preprocess()

Expand All @@ -186,22 +199,22 @@ with the Triton Inference Server.
# Setting up client
client = httpclient.InferenceServerClient(url="localhost:8000")

Secondly, we specify the names of the input and output layer(s) of our model.
Secondly, we specify the names of the input and output layer(s) of our model. This can be obtained during export and should already be specified in your ``config.pbtxt``

::

inputs = httpclient.InferInput("input__0", transformed_img.shape, datatype="FP32")
inputs = httpclient.InferInput("x", transformed_img.shape, datatype="FP32")
inputs.set_data_from_numpy(transformed_img, binary_data=True)

outputs = httpclient.InferRequestedOutput("output__0", binary_data=True, class_count=1000)
outputs = httpclient.InferRequestedOutput("output0", binary_data=True, class_count=1000)

Lastly, we send an inference request to the Triton Inference Server.

::

# Querying the server
results = client.infer(model_name="resnet50", inputs=[inputs], outputs=[outputs])
inference_output = results.as_numpy('output__0')
inference_output = results.as_numpy('output0')
print(inference_output[:5])

The output should look like below:
Expand All @@ -214,3 +227,11 @@ The output should look like below:
The output format here is ``<confidence_score>:<classification_index>``.
To learn how to map these to the label names and more, refer to Triton Inference Server's
`documentation <https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_classification.md>`__.

You can try out this client quickly using

::

# Remember to use the same publishing tag for all steps (e.g. 24.08)

docker run -it --net=host -v ${PWD}:/triton_example nvcr.io/nvidia/tritonserver:YY.MM-py3-sdk bash -c "pip install torchvision && python /triton_example/client.py"
Loading

0 comments on commit 06484f7

Please sign in to comment.