Skip to content

Commit

Permalink
fix: pydantic warnings about model_ namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
flxst committed Aug 2, 2024
1 parent 65bf611 commit a29fe8c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from omegaconf import OmegaConf
from pydantic import BaseModel, Field, FilePath, PositiveInt, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, FilePath, PositiveInt, field_validator, model_validator
from torch.distributed.fsdp import ShardingStrategy
from transformers import GPT2TokenizerFast
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
Expand Down Expand Up @@ -234,6 +234,10 @@ class WeightInitializedModelConfig(BaseModel):
model: PydanticPytorchModuleType
model_initializer: PydanticModelInitializationIFType

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
model_config = ConfigDict(protected_namespaces=())


class PreTrainedHFTokenizerConfig(BaseModel):
pretrained_model_name_or_path: str
Expand Down
6 changes: 5 additions & 1 deletion src/modalities/config/instantiation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Annotated, Dict, List, Optional

from pydantic import BaseModel, Field, FilePath, field_validator
from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator

from modalities.config.pydanctic_if_types import (
PydanticCheckpointSavingIFType,
Expand Down Expand Up @@ -82,6 +82,10 @@ class TextGenerationSettings(BaseModel):
device: PydanticPytorchDeviceType
referencing_keys: Dict[str, str]

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
model_config = ConfigDict(protected_namespaces=())

@field_validator("device", mode="before")
def parse_device(cls, device) -> PydanticPytorchDeviceType:
    """Normalize the raw ``device`` value into a torch device before field validation."""
    parsed_device = parse_torch_device(device)
    return parsed_device
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional

import torch.nn as nn
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated

from modalities.config.pydanctic_if_types import PydanticModelInitializationIFType
Expand All @@ -17,6 +17,10 @@
class ModelInitializerWrapperConfig(BaseModel):
    """Config holding the list of model initializers to wrap."""

    # Disable pydantic's protected namespaces so the "model_" prefix of the
    # field below does not raise a warning, see
    # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
    model_config = ConfigDict(protected_namespaces=())

    # NOTE(review): element type is a project-declared interface type — confirm
    # its validation semantics against modalities.config.pydanctic_if_types.
    model_initializers: List[PydanticModelInitializationIFType]


class ComposedModelInitializationConfig(BaseModel):
model_type: SupportWeightInitModels
Expand All @@ -27,6 +31,10 @@ class ComposedModelInitializationConfig(BaseModel):
hidden_dim: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
num_layers: Optional[Annotated[int, Field(strict=True, gt=0)]] = None

# avoid warning about protected namespace 'model_', see
# https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces
model_config = ConfigDict(protected_namespaces=())

@model_validator(mode="after")
def _check_values(self):
# in case of initialization with "auto", we need to specify the hidden_dim
Expand Down

0 comments on commit a29fe8c

Please sign in to comment.