Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Pavel Geyn committed Jan 21, 2025
1 parent 386dfed commit c407488
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions turbo_alignment/common/tf/loaders/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ def load_model(
fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
)

if model_settings.sequence_parallel_degree:
import turbo_alignment.modeling.parallel_states as parallel_states
from turbo_alignment.modeling.gemma2.patch import patch_gemma_attn_dict
# if model_settings.sequence_parallel_degree:
# import turbo_alignment.modeling.parallel_states as parallel_states
# from turbo_alignment.modeling.gemma2.patch import patch_gemma_attn_dict

patch_gemma_attn_dict()
# patch_gemma_attn_dict()

parallel_states.initialize_model_parallel(sequence_parallel_size=model_settings.sequence_parallel_degree)
assert parallel_states.sequence_parallel_is_initialized()
# parallel_states.initialize_model_parallel(sequence_parallel_size=model_settings.sequence_parallel_degree)
# assert parallel_states.sequence_parallel_is_initialized()

model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained(
model_settings.model_path,
Expand Down

0 comments on commit c407488

Please sign in to comment.