Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
eegli committed Jan 28, 2025
1 parent 0fe26f0 commit 1a4f92b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/mblm/model/mblm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,18 @@ def __init__(self, cfg: MBLMModelConfig):
self.start_tokens = nn.ParameterList(
[nn.Parameter(torch.randn(model_dim)) for model_dim in cfg.hidden_dims]
)
stage_blocks = cfg.stage_blocks()

self.pos_embs = self._init_positional_embeddings(
cfg.hidden_dims, cfg.seq_lens, cfg.stage_blocks
cfg.hidden_dims, cfg.seq_lens, stage_blocks
)

self.token_embs_rev = self._init_token_embeddings(
cfg.hidden_dims, cfg.seq_lens, cfg.num_tokens, cfg.pad_token_id
)

self.stage_models, self.to_next_stage_proj = self._init_models_at_stages(
cfg.hidden_dims, cfg.seq_lens, cfg.num_layers, cfg.stage_blocks
cfg.hidden_dims, cfg.seq_lens, cfg.num_layers, stage_blocks
)

self.to_logits = nn.Linear(cfg.hidden_dims[-1], cfg.num_tokens)
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/config/test_sample_config_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def ensure_dataset_args_are_valid(self, config: TrainEntryConfig) -> None:
return None

def ensure_model_is_created(self, config: TrainEntryConfig) -> None:
for b in config.params.stage_blocks:
for b in config.params.stage_blocks():
assert isinstance(b, (TransformerBlock, MambaBlock))
if isinstance(b, TransformerBlock):
assert b.block_type == "transformer"
Expand Down

0 comments on commit 1a4f92b

Please sign in to comment.