Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa committed Nov 20, 2024
1 parent 1a8a842 commit 346023e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions nemo/collections/llm/gpt/model/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@
import torch.nn.functional as F

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'F' is not used.
from torch import nn
from typing_extensions import Annotated

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'Annotated' is not used.
from nemo.lightning.pytorch.optim import OptimizerModule
from pathlib import Path

from nemo.collections.llm.gpt.model.baichuan import Baichuan2Model

# import nemo.collections.llm.gpt.model as llm
from nemo.collections.llm.gpt.model.base import GPTModel
from nemo.collections.llm.gpt.model.chatglm import ChatGLMModel
from nemo.collections.llm.gpt.model.phi3mini import Phi3Model
from nemo.collections.llm.gpt.model.mixtral import MixtralModel
from nemo.collections.llm.gpt.model.mistral import MistralModel
from nemo.collections.llm.gpt.model.gemma import GemmaModel
from nemo.collections.llm.gpt.model.gemma2 import Gemma2Model
from nemo.collections.llm.gpt.model.llama import LlamaModel
from nemo.collections.llm.gpt.model.baichuan import Baichuan2Model
from nemo.collections.llm.gpt.model.mistral import MistralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralModel
from nemo.collections.llm.gpt.model.phi3mini import Phi3Model
from nemo.collections.llm.gpt.model.qwen2 import Qwen2Model
from nemo.collections.llm.gpt.model.starcoder import StarcoderModel
from nemo.collections.llm.gpt.model.starcoder2 import Starcoder2Model
from nemo.lightning.pytorch.optim import OptimizerModule

HF_TO_MCORE_REGISTRY = {
'ChatGLMForCausalLM': ChatGLMModel,
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(

# Get model class from registry
from transformers import AutoConfig

architectures = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True).architectures
assert isinstance(architectures, list), "Expected architectures to be a list"
assert len(architectures) == 1, "Expected architectures to contain one item"
Expand All @@ -90,13 +91,12 @@ def __init__(
config = model_cls.importer(import_path).config

# Init class
super().__init__(
config, optim=optim, tokenizer=tokenizer, model_transform=model_transform
)
super().__init__(config, optim=optim, tokenizer=tokenizer, model_transform=model_transform)

# Change self's class to model_cls
self.__class__ = model_cls


__all__ = [
"AutoModel",
]

0 comments on commit 346023e

Please sign in to comment.