diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index 43b061f9..f7a5ae66 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -78,14 +78,14 @@ def load_model( fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, ) - if model_settings.sequence_parallel_degree: - import turbo_alignment.modeling.parallel_states as parallel_states - from turbo_alignment.modeling.gemma2.patch import patch_gemma_attn_dict + # if model_settings.sequence_parallel_degree: + # import turbo_alignment.modeling.parallel_states as parallel_states + # from turbo_alignment.modeling.gemma2.patch import patch_gemma_attn_dict - patch_gemma_attn_dict() + # patch_gemma_attn_dict() - parallel_states.initialize_model_parallel(sequence_parallel_size=model_settings.sequence_parallel_degree) - assert parallel_states.sequence_parallel_is_initialized() + # parallel_states.initialize_model_parallel(sequence_parallel_size=model_settings.sequence_parallel_degree) + # assert parallel_states.sequence_parallel_is_initialized() model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained( model_settings.model_path,