Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: akoumpa <[email protected]>
  • Loading branch information
akoumpa committed Nov 20, 2024
1 parent 3e9b787 commit 57ca23f
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions nemo/collections/llm/gpt/model/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Callable, Optional

from torch import nn
from nemo.lightning.pytorch.optim import OptimizerModule
from pathlib import Path

from nemo.collections.llm.gpt.model.baichuan import Baichuan2Model
from nemo.collections.llm.gpt.model.base import GPTModel
from nemo.collections.llm.gpt.model.chatglm import ChatGLMModel
from nemo.collections.llm.gpt.model.phi3mini import Phi3Model
from nemo.collections.llm.gpt.model.mixtral import MixtralModel
from nemo.collections.llm.gpt.model.mistral import MistralModel
from nemo.collections.llm.gpt.model.gemma import GemmaModel
from nemo.collections.llm.gpt.model.gemma2 import Gemma2Model
from nemo.collections.llm.gpt.model.llama import LlamaModel
from nemo.collections.llm.gpt.model.baichuan import Baichuan2Model
from nemo.collections.llm.gpt.model.mistral import MistralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralModel
from nemo.collections.llm.gpt.model.phi3mini import Phi3Model
from nemo.collections.llm.gpt.model.qwen2 import Qwen2Model
from nemo.collections.llm.gpt.model.starcoder import StarcoderModel
from nemo.collections.llm.gpt.model.starcoder2 import Starcoder2Model
from nemo.lightning.pytorch.optim import OptimizerModule

HF_TO_MCORE_REGISTRY = {
'ChatGLMForCausalLM': ChatGLMModel,
Expand Down Expand Up @@ -67,6 +67,7 @@ def __init__(

# Get model class from registry
from transformers import AutoConfig

architectures = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True).architectures
assert isinstance(architectures, list), "Expected architectures to be a list"
assert len(architectures) == 1, "Expected architectures to contain one item"
Expand All @@ -84,13 +85,12 @@ def __init__(
config = model_cls.importer(import_path).config

# Init class
super().__init__(
config, optim=optim, tokenizer=tokenizer, model_transform=model_transform
)
super().__init__(config, optim=optim, tokenizer=tokenizer, model_transform=model_transform)

# Change self's class to model_cls
self.__class__ = model_cls


__all__ = [
"AutoModel",
]

0 comments on commit 57ca23f

Please sign in to comment.