diff --git a/collect-deepseek-predictor-data.py b/collect-deepseek-predictor-data.py index e4a7873..79d10da 100644 --- a/collect-deepseek-predictor-data.py +++ b/collect-deepseek-predictor-data.py @@ -361,6 +361,7 @@ def _custom_ffn_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for name, module in model.named_modules(): if type(module).__name__ == 'DeepSeekMoE': block_ffn_input_output_pair[name] = [] + module._original_forward = module.forward module._module_name = name module.forward = _custom_ffn_forward.__get__(module, type(module))