From 8f04255bed6e15282ef4886227914af5ce93b08b Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 30 Oct 2024 01:31:21 -0400 Subject: [PATCH] More Pre-Commit Fixes, I need to make this actually run pre-commit --- src/levanter/data/audio.py | 3 ++- src/levanter/main/train_asr.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index 017c09a65..4b5d4fb36 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -81,7 +81,7 @@ def __init__( padding=True, ): self.feature_extractor: SequenceFeatureExtractor = processor.feature_extractor - if tokenizer.pad_token_id == None: + if tokenizer.pad_token_id is None: override_token = list(tokenizer.added_tokens_decoder.items())[-1] tokenizer.pad_token_id = override_token[0] tokenizer.pad_tokn = str(override_token[1]) @@ -275,6 +275,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]): def __init__(self, cache: TreeCache[AudioTextDict]): self.cache = cache + self._cached_len: Optional[int] = None async def async_len(self) -> int: return await self.cache.async_len() diff --git a/src/levanter/main/train_asr.py b/src/levanter/main/train_asr.py index 59f507d98..86cc04cb1 100644 --- a/src/levanter/main/train_asr.py +++ b/src/levanter/main/train_asr.py @@ -16,8 +16,8 @@ from levanter.compat.hf_checkpoints import HFCompatConfig, ModelWithHfSerializationMixin, save_hf_checkpoint_callback from levanter.data.audio import AudioIODatasetConfig, AudioMixtureDatasetConfig, AudioTextDataset from levanter.models.asr_model import ASRConfig, AudioTextExample +from levanter.models.diva import DivaASRModel, diva_connector_only from levanter.models.whisper import WhisperConfig -from levanter.models.diva import diva_connector_only from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -138,12 +138,15 @@ def compute_loss( if vocab_size != Vocab.size: logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - state = trainer.initial_state(training_key, model_init=lambda: config.model.build_asr(Vocab, key=model_key), ) + state = trainer.initial_state( + training_key, + model_init=lambda: config.model.build_asr(Vocab, key=model_key), + ) if int(state.step) == 0: - if config.diva_training: + if config.diva_training and config.model.asr_model_type == DivaASRModel: state = dataclasses.replace(state, model=None) - model = config.model.asr_model_type.init(Vocab, config.model, key=model_key, init_from_submodels=True) + model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True) model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model)) # TODO: I don't love that we init the model twice, but it's not a big deal i think?