
🐛 [Bug] Difficulties Quantizing FP16 Models to INT8 Using torch_tensorrt (MLP, CNN, Attention, LSTM, Transformer) #3494

Open
NisFu-gh opened this issue Apr 27, 2025 · 2 comments
Labels: bug (Something isn't working)


NisFu-gh commented Apr 27, 2025

Bug Description

I am trying to quantize already-trained FP16 models to INT8 precision using torch_tensorrt and to accelerate inference with TensorRT engines. During this process I ran into several different issues, originating either in torch_tensorrt or in TensorRT itself (I am not entirely sure which).

In most cases, the models fail during quantization and/or compilation.

To Reproduce

  1. Define several common models (MLP, CNN, Attention, LSTM, Transformer) in PyTorch.
  2. Randomly initialize model weights.
  3. Convert the models to FP16 precision and move them to GPU.
  4. Compile models using torch_tensorrt:
    • Compile to FP16 TensorRT engine.
    • Compile and quantize to INT8 TensorRT engine.
  5. Compare inference performance and accuracy between:
    • Original FP16 model
    • FP16 TensorRT-compiled model
    • INT8 TensorRT-compiled-quantized model

Here is the minimal reproducible code:

# (One can switch between different models and IRs by modifying the comments.)
import contextlib
import time

import torch
import torch.nn as nn
import torch_tensorrt
from torch.utils.data import DataLoader, Dataset
from torch_tensorrt.ts import ptq

class MLPModel(nn.Module):
    def __init__(self, seq_len, in_dim, hidden_sizes):
        super().__init__()
        layers = []
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(in_dim, hidden_size))
            in_dim = hidden_size
        self.mlp = nn.Sequential(*layers)
        self.end_layer = nn.Linear(in_dim, 48)
        self.active = nn.ReLU()

    def forward(self, x):
        B, T, C = x.shape
        x = x.transpose(0, 1)  # (T, B, C)
        x = self.mlp(x)
        x = self.active(x)
        x = self.end_layer(x)
        x = x.mean(dim=0)  # (B, 48)
        return x

class CNNModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        layers = []
        input_channels = in_dim
        for _ in range(num_layers):
            layers.append(nn.Conv1d(input_channels, 256, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
            input_channels = 256
        self.conv = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, 48)

    def forward(self, x):
        B, T, C = x.shape
        x = x.transpose(1, 2)  # (B, C, T)
        x = self.conv(x)       # (B, 256, T)
        x = self.pool(x)       # (B, 256, 1)
        x = x.squeeze(-1)      # (B, 256)
        x = self.fc(x)         # (B, 48)
        return x

class AttentionModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(nn.MultiheadAttention(embed_dim=in_dim, num_heads=4, batch_first=True))
        self.attention_layers = nn.ModuleList(layers)
        self.fc = nn.Linear(in_dim, 48)

    def forward(self, x):
        B, T, C = x.shape
        for attn in self.attention_layers:
            x, _ = attn(x, x, x)
        x = x.mean(dim=1)  # (B, C)
        x = self.fc(x)     # (B, 48)
        return x

class TransformerModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=in_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(in_dim, 48)

    def forward(self, x):
        B, T, C = x.shape
        x = self.transformer(x)  # (B, T, C)
        x = x.mean(dim=1)        # (B, C)
        x = self.fc(x)           # (B, 48)
        return x

class LSTMModel(nn.Module):
    def __init__(self, seq_len, in_dim, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=in_dim,
            hidden_size=256,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(256, 48)

    def forward(self, x):
        B, T, C = x.shape
        output, (hn, cn) = self.lstm(x)  # output: (B, T, 256)
        x = output.mean(dim=1)           # (B, 256)
        x = self.fc(x)                   # (B, 48)
        return x

@torch.no_grad()
def run_model_with_profiling(
    model, example_inputs, num_warmup, num_runs, msg=""
):
    for _ in range(num_warmup):
        _ = model(*example_inputs)

    torch.cuda.synchronize()
    start_time = time.time()

    for _ in range(num_runs):
        output = model(*example_inputs)
        torch.cuda.synchronize()

    torch.cuda.synchronize()
    end_time = time.time()
    avg_time = (end_time - start_time) * 1000 / num_runs
    print(f"{msg} {avg_time=:.4f} ms, {output.dtype=}")
    return output


class CalibrationDataset(Dataset):
    def __init__(self, length, shape):
        self.length = length
        self.shape = shape

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.rand(self.shape).half().cuda()

sample_shape = (1024, 512)
calib_num = 256
batch_size = 256
warmup = 5
infer = 10
profile = False
ir = "torch_compile"  # alternatives: "torchscript", "dynamo" (must be defined; compile() below uses it)

seq_len, item_len = 256, 512
model = MLPModel(seq_len, item_len, [1024] * 5)
# model = CNNModel(seq_len, item_len, num_layers=5)
# model = AttentionModel(seq_len, item_len, num_layers=2)
# model = TransformerModel(seq_len, item_len, num_layers=2)
# model = LSTMModel(seq_len, item_len, num_layers=2)

calib_dataset = CalibrationDataset(calib_num, sample_shape)
calib_dataloader = DataLoader(calib_dataset, batch_size=32, shuffle=False)

model = model.eval().half().cuda()
example_input = torch.randn(batch_size, *sample_shape).half().cuda()
print(f">>> model: {model.__class__}")

try:
    model(example_input)
    fp16_output = run_model_with_profiling(model, [example_input, ], warmup, infer, "fp16")
except Exception as e:
    print(f">>> original model error: {e}")
else:
    print(">>> original model passed")

try:
    compiled_model = torch_tensorrt.compile(
        model,
        ir=ir,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[batch_size, *sample_shape],
                opt_shape=[batch_size, *sample_shape],
                max_shape=[batch_size, *sample_shape],
                dtype=torch.half,
            )
        ],
        enabled_precisions={torch.float16},
        truncate_long_and_double=True,
    )
    _ = run_model_with_profiling(compiled_model, [example_input, ], warmup, infer, "trt")
except Exception as e:
    print(f">>> tensorrt compile error: {e}")
else:
    print(">>> tensorrt compile passed")

try:
    quantized_compiled_model = torch_tensorrt.compile(
        model,
        ir=ir,
        inputs=[
            torch_tensorrt.Input(
                min_shape=[batch_size, *sample_shape],
                opt_shape=[batch_size, *sample_shape],
                max_shape=[batch_size, *sample_shape],
                dtype=torch.half,
            )
        ],
        enabled_precisions={torch.int8},
        calibrator=ptq.DataLoaderCalibrator(
            calib_dataloader,
            algo_type=ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
            device=torch.device("cuda:0"),
        ),
        truncate_long_and_double=True,
    )
    int8_output = run_model_with_profiling(quantized_compiled_model, [example_input, ], warmup, infer, "quant + trt")
except Exception as e:
    print(f">>> quantized tensorrt compile error: {e}")
else:
    print(">>> quantized tensorrt compile passed")
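The script stores fp16_output and int8_output but stops short of the accuracy comparison in step 5. A minimal sketch of that comparison in plain Python over flattened outputs could look like the following; compare_outputs is a hypothetical helper, and max absolute error plus cosine similarity are just one reasonable choice of metrics. With the tensors from the script, the inputs would be something like fp16_output.float().flatten().tolist().

```python
import math


def compare_outputs(reference, candidate):
    """Return (max_abs_err, cosine_similarity) for two equal-length float sequences.

    Hypothetical helper: `reference` would be the flattened FP16 output and
    `candidate` the flattened INT8 TensorRT output from the script above.
    """
    if len(reference) != len(candidate) or not reference:
        raise ValueError("outputs must be non-empty and the same length")
    max_abs_err = max(abs(r - c) for r, c in zip(reference, candidate))
    dot = sum(r * c for r, c in zip(reference, candidate))
    norm_ref = math.sqrt(sum(r * r for r in reference))
    norm_cand = math.sqrt(sum(c * c for c in candidate))
    if norm_ref == 0.0 or norm_cand == 0.0:
        raise ValueError("cosine similarity is undefined for an all-zero output")
    return max_abs_err, dot / (norm_ref * norm_cand)


if __name__ == "__main__":
    # Toy numbers standing in for real model outputs.
    err, cos = compare_outputs([0.5, -1.25, 2.0, 0.0], [0.5, -1.0, 2.0, 0.25])
    print(f"max_abs_err={err:.4f} cosine={cos:.4f}")
```

A cosine similarity close to 1.0 together with a small maximum error is a quick sanity signal that INT8 quantization has not broken the model; it does not replace a task-level accuracy metric.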

Expected Behavior

The FP16 models compile successfully to INT8 TensorRT engines while maintaining reasonable inference accuracy and performance.

Actual Behavior

In most cases, compilation fails or the resulting models do not run correctly. Below is a summary of the results I tested:

Model        IR             FP16  FP16 + TRT  INT8 + TRT  Error log
MLP          torch_compile  pass  pass        failed      see [1]
MLP          torchscript    pass  pass        pass        N/A
MLP          dynamo         pass  pass        failed      see [2]
CNN          torch_compile  pass  pass        failed      see [3]
CNN          torchscript    pass  pass        failed      see [4]
CNN          dynamo         pass  pass        failed      see [5]
Attention    torch_compile  pass  pass        pass        N/A
Attention    torchscript    pass  failed      failed      see [6]
Attention    dynamo         pass  failed      failed      see [7] [8]
Transformer  torch_compile  pass  pass        pass        N/A
Transformer  torchscript    pass  failed      failed      see [9]
Transformer  dynamo         pass  failed      failed      see [10] [11]
LSTM         torch_compile  pass  pass        pass        N/A
LSTM         torchscript    pass  failed      failed      see [12]
LSTM         dynamo         pass  failed      failed      see [13]

The corresponding error logs are attached as error_log.txt (they were too long to include inline).

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0+cu124
  • PyTorch Version (e.g. 1.0): 2.6.0+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Rocky Linux 8.7
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): N/A
  • Are you using local sources or building from archives: N/A
  • Python version: 3.10.16
  • CUDA version: 12.5
  • GPU models and configuration: NVIDIA GeForce RTX 4090
  • Any other relevant information: N/A

Questions

Am I using torch_tensorrt incorrectly?
Are there any important documentation notes or best practices regarding compilation and quantization that I might have missed?
What is the correct (or officially recommended) way to do this task: given an FP16 model, build an INT8-quantized version and run inference with the TensorRT backend?

Any help would be greatly appreciated! Thank you in advance!

Additional context

N/A

NisFu-gh added the bug label on Apr 27, 2025.
narendasan (Collaborator) commented:

Thanks for filing such a detailed bug report. We haven't had a chance to dig into it yet, but in general, using the DataLoaderCalibrator from the TorchScript frontend is deprecated because it relies on deprecated TensorRT APIs. The new workflow uses the TensorRT Model Optimizer toolkit: https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_ptq.html

NisFu-gh (Author) commented:

Hi @narendasan, thank you for your reply! I have since done some further exploration and now understand the issue I raised much better. I'd like to post an update here, and also clarify some of my earlier misunderstandings about torch_tensorrt.

First, regarding the Model Optimizer (modelopt) workflow you mentioned: I followed the documentation and tried it. However, in that workflow the model still has to be compiled with torch_tensorrt after it has been optimized with modelopt, and during that compilation I hit the same errors described in my original post. In short, simply replacing the original calibration step with modelopt does not resolve all of the errors I'm facing, and I suspect some of them are caused by torch_tensorrt.compile itself.

Next, I want to correct a mistake in my original post. Some of the errors in the error log I shared are:

ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (Calibration failure occurred with no scaling factors detected. This could be due to no int8 calibrator or insufficient custom scales for network layers. Please see int8 sample to setup calibration correctly.)
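For context, the "scaling factors" in this message are the per-tensor scales that calibration is supposed to produce: each tensor is mapped onto the INT8 range through a scale derived from the values observed on calibration data. The sketch below shows the simplest variant, symmetric max calibration, in plain Python. It is not the entropy (KL-divergence) algorithm selected by ENTROPY_CALIBRATION_2 in the repro, and the function names are hypothetical.

```python
def max_calibration_scale(calibration_batches, qmax=127):
    """Symmetric per-tensor scale from the largest |x| seen during calibration."""
    amax = 0.0
    for batch in calibration_batches:
        for x in batch:
            amax = max(amax, abs(x))
    if amax == 0.0:
        # No batches (or only zeros) were observed, so there is no usable
        # scaling factor; loosely analogous to the build error above.
        raise RuntimeError("calibration observed no non-zero values")
    return amax / qmax


def quantize(x, scale, qmax=127):
    """Round x / scale to the nearest integer and clamp to [-qmax, qmax]."""
    return max(-qmax, min(qmax, round(x / scale)))


def dequantize(q, scale):
    """Map an INT8 code back to the real-valued domain."""
    return q * scale


if __name__ == "__main__":
    scale = max_calibration_scale([[0.5, -1.27], [1.0]])
    q = quantize(0.333, scale)
    print(scale, q, dequantize(q, scale))
```

Values inside the calibrated range round-trip with an error of at most half a scale step; values outside it are clipped, which is why calibration data should resemble real inputs.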

After quantizing the model manually with the TensorRT Python API, I found that this error can be partly attributed to a mismatch between the batch size of the calibration data and the opt_shape of the optimization profile. For example, in the demo code I set the calibration dataset shape to (1, 1024, 512), but the model's opt_shape was (256, 1024, 512). I believe this detail should be stated explicitly in the documentation because, to be honest, the error message is quite ambiguous.
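Based on that observation, one way to avoid the mismatch is to derive the calibration batch size from opt_shape rather than choosing it independently. The helper below is a hypothetical pre-flight check; it assumes (as reported in this thread, not as documented TensorRT behavior) that every calibration batch must have exactly opt_shape. It also counts full batches, since a short final batch would reintroduce the mismatch.

```python
def calibration_batch_size(opt_shape, sample_shape, num_samples):
    """Hypothetical helper: choose a calibration batch size that matches opt_shape.

    Assumes every calibration batch must have exactly opt_shape, i.e.
    opt_shape == (batch, *sample_shape). Returns (batch_size, num_full_batches);
    samples beyond num_full_batches * batch_size should be dropped so that no
    short final batch is fed to the calibrator.
    """
    opt_shape, sample_shape = tuple(opt_shape), tuple(sample_shape)
    if opt_shape[1:] != sample_shape:
        raise ValueError(
            f"per-sample shape {sample_shape} does not match opt_shape {opt_shape}"
        )
    batch_size = opt_shape[0]
    if num_samples < batch_size:
        raise ValueError(
            f"need at least {batch_size} calibration samples, got {num_samples}"
        )
    return batch_size, num_samples // batch_size


if __name__ == "__main__":
    # Values from the repro: opt_shape = (batch_size, *sample_shape), calib_num = 256.
    print(calibration_batch_size((256, 1024, 512), (1024, 512), 256))
```

For the repro values this yields (256, 1), which suggests building the calibration loader as DataLoader(calib_dataset, batch_size=256, shuffle=False, drop_last=True) instead of batch_size=32.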

Lastly, in case other users run into similar issues, here is a workaround that has worked for me. As mentioned above, I now use the TensorRT Python API directly to quantize the model. The full workflow is: export the model to ONNX with PyTorch, load the ONNX model with TensorRT, perform calibration and quantization, and finally build the inference engine. With this approach I was able to quantize all five toy models from the sample code to INT8 while maintaining very good accuracy. One caveat is that a serialized ONNX model is limited to 2 GB (a protobuf limit), so larger models may need a different approach.
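To judge whether the 2 GB limit applies to a given model, a rough estimate of the weight size is usually enough. The sketch below is a back-of-the-envelope check with hypothetical helper names: it counts nn.Linear parameters for the repro's MLPModel and compares the FP16 byte size against protobuf's 2^31 - 1 byte cap, ignoring graph overhead. For models that do exceed it, the ONNX format also supports storing large tensors as external data files.

```python
# protobuf refuses to serialize a single message larger than this
PROTOBUF_LIMIT_BYTES = 2**31 - 1


def linear_params(in_features, out_features, bias=True):
    """Parameter count of nn.Linear(in_features, out_features, bias=bias)."""
    return in_features * out_features + (out_features if bias else 0)


def fits_in_single_onnx_file(num_params, bytes_per_param=2):
    """Rough check that the weights alone stay under the protobuf limit.

    bytes_per_param=2 corresponds to FP16 weights. Graph structure and
    metadata are ignored, so treat results close to the limit as "no".
    """
    return num_params * bytes_per_param < PROTOBUF_LIMIT_BYTES


if __name__ == "__main__":
    # MLPModel(seq_len, 512, [1024] * 5) from the repro:
    # Linear(512, 1024), four Linear(1024, 1024), then end_layer Linear(1024, 48).
    mlp_params = (
        linear_params(512, 1024)
        + 4 * linear_params(1024, 1024)
        + linear_params(1024, 48)
    )
    print(mlp_params, fits_in_single_onnx_file(mlp_params))
```

The toy MLP is only about 4.8M parameters (roughly 9 MB in FP16), so the limit only becomes relevant for models with on the order of a billion or more FP16 parameters.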
