diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py index 2d9c41b603..bd31ca2866 100644 --- a/examples/models/llama3_2_vision/text_decoder/model.py +++ b/examples/models/llama3_2_vision/text_decoder/model.py @@ -108,6 +108,7 @@ def __init__(self, **kwargs): rope_base=params["rope_theta"], intermediate_dim=params["intermediate_dim"], ) + self.model_.requires_grad_(False) # Source transformation for MultiHeadAttention self.model_ = replace_mha_with_inference_mha(self.model_)