Skip to content

Commit

Permalink
[#102] initial upgrade to pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
pkdash committed Oct 9, 2024
1 parent 9df9120 commit 206da75
Show file tree
Hide file tree
Showing 15 changed files with 192 additions and 132 deletions.
2 changes: 1 addition & 1 deletion hydroshare_on_jupyter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _load_jupyter_server_extension(server_app: ServerApp):
config = ConfigFile()

# pass config file settings to Tornado Application (web app)
server_app.web_app.settings.update(config.dict())
server_app.web_app.settings.update(config.model_dump())


# For backward compatibility with the classical notebook
Expand Down
2 changes: 1 addition & 1 deletion hydroshare_on_jupyter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def configure_jupyter() -> None:
def start_stand_alone_session(
hostname: str, port: int, debug: bool, config: ConfigFile
) -> None:
app = get_test_app(default_hostname=hostname, debug=debug, **config.dict())
app = get_test_app(default_hostname=hostname, debug=debug, **config.model_dump())

logging.info(f"Server starting on {hostname}:{port}")
logging.info(f"Debugging mode {'enabled' if debug else 'disabled'}")
Expand Down
52 changes: 31 additions & 21 deletions hydroshare_on_jupyter/config_setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseSettings, Field, root_validator, validator
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
import pickle
from pathlib import Path
from typing import Optional, Union
Expand All @@ -21,27 +22,36 @@ def __init__(self, message: str) -> None:

class ConfigFile(BaseSettings):
# case-insensitive alias values DATA and LOG
data_path: Path = Field(_DEFAULT_DATA_PATH, env="data")
log_path: Path = Field(_DEFAULT_LOG_PATH, env="log")
oauth_path: Union[OAuthFile, str, None] = Field(None, env="oauth")
data_path: Optional[Path] = Field(_DEFAULT_DATA_PATH, validation_alias="data")
log_path: Optional[Path] = Field(_DEFAULT_LOG_PATH, validation_alias="log")
oauth_path: Union[OAuthFile, str, None] = Field(None, validation_alias="oauth")

class Config:
env_file: Union[str, None] = first_existing_file(_DEFAULT_CONFIG_FILE_LOCATIONS)
env_file_encoding = "utf-8"
model_config = SettingsConfigDict(
env_file=first_existing_file(_DEFAULT_CONFIG_FILE_LOCATIONS),
env_file_encoding='utf-8'
)
# TODO: cleanup
# class Config:
# env_file: Union[str, None] = first_existing_file(_DEFAULT_CONFIG_FILE_LOCATIONS)
# env_file_encoding = "utf-8"

@validator("data_path", "log_path", pre=True)
def create_paths_if_do_not_exist(cls, v: Path):
# for key, path in values.items():
path = expand_and_resolve(v)
if path.is_file():
raise FileNotDirectoryError(
f"Configuration path: {str(path)} is a file not a directory."
)
elif not path.exists():
path.mkdir(parents=True)
return path

@validator("oauth_path")
@model_validator(mode="after")
def create_paths_if_do_not_exist(self):

def check_path(path: Path):
path = expand_and_resolve(path)
if path.is_file():
raise FileNotDirectoryError(
f"Configuration path: {str(path)} is a file not a directory."
)
elif not path.exists():
path.mkdir(parents=True)

check_path(self.data_path)
check_path(self.log_path)
return self

@field_validator("oauth_path")
def unpickle_oauth_path(cls, v):
if v is None:
return v
Expand All @@ -53,4 +63,4 @@ def unpickle_oauth_path(cls, v):
with open(path, "rb") as f:
deserialized_model = pickle.load(f)

return OAuthFile.parse_obj(deserialized_model)
return OAuthFile.model_validate(deserialized_model)
4 changes: 3 additions & 1 deletion hydroshare_on_jupyter/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path
from notebook.utils import url_path_join
# TODO: cleanup
# from notebook.utils import url_path_join
from jupyter_server.utils import url_path_join
import tornado
from ..websocket_handler import FileSystemEventWebSocketHandler
from ..server import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from pathlib import Path
from typing import Set, List, TYPE_CHECKING
from .fs_resource_map import LocalFSResourceMap, RemoteFSResourceMap
Expand Down Expand Up @@ -62,8 +62,8 @@ def from_resource_maps(
)


class AggregateFSResourceMapSyncStateCollection(BaseModel):
__root__: List[AggregateFSResourceMapSyncState]
class AggregateFSResourceMapSyncStateCollection(RootModel):
root: List[AggregateFSResourceMapSyncState]

@classmethod
def from_aggregate_map(
Expand All @@ -74,7 +74,7 @@ def from_aggregate_map(

res_intersection = set(lm) & set(rm)

return cls.parse_obj(
return cls.model_validate(
[
AggregateFSResourceMapSyncState.from_resource_maps(
local_resource_map=lm[res_id], remote_resource_map=rm[res_id]
Expand Down
90 changes: 55 additions & 35 deletions hydroshare_on_jupyter/models/api_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from pydantic import (
RootModel,
BaseModel,
Field,
StrictStr,
StrictBool,
constr,
validator,
field_validator,
ConfigDict,
StringConstraints,
)
from typing import List, Union
from typing import List, Union, Literal, Any, Annotated
from hsclient import Token

from .resource_type_enum import ResourceTypeEnum
Expand All @@ -15,8 +18,10 @@
class ModelNoExtra(BaseModel):
"""does not permit extra fields"""

class Config:
extra = "forbid"
model_config = ConfigDict(extra="forbid")
# TODO: cleanup - also cleanup imports above
# class Config:
# extra = "forbid"


class Boolean(BaseModel):
Expand All @@ -39,32 +44,42 @@ class OAuthCredentials(ModelNoExtra):

CredentialTypes = Union[StandardCredentials, OAuthCredentials]


class Credentials(BaseModel):
__root__: CredentialTypes = Field(...)

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False
) -> "DictStrAny":
d = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
# return contents of root key dropping it in the process
return d["__root__"]
Credentials = RootModel[CredentialTypes]

# TODO: cleanup
# class Credentials(RootModel, BaseModel):
# root: CredentialTypes = Field(...)
#
# def model_dump(
# self,
# *,
# mode: Literal['json', 'python'] | str = 'python',
# include: "IncEx" = None,
# exclude: "IncEx" = None,
# context: dict[str, Any] | None = None,
# by_alias: bool = False,
# exclude_unset: bool = False,
# exclude_defaults: bool = False,
# exclude_none: bool = False,
# round_trip: bool = False,
# warnings: bool | Literal['none', 'warn', 'error'] = True,
# serialize_as_any: bool = False,
# ) -> dict[str, Any]:
# d = super().model_dump(
# mode=mode,
# include=include,
# exclude=exclude,
# context=context,
# by_alias=by_alias,
# exclude_unset=exclude_unset,
# exclude_defaults=exclude_defaults,
# exclude_none=exclude_none,
# round_trip=round_trip,
# warnings=warnings,
# serialize_as_any=serialize_as_any,
# )
# # return contents of root key dropping it in the process
# return d["root"]


class Success(BaseModel):
Expand All @@ -83,18 +98,18 @@ class ResourceMetadata(BaseModel):
authors: List[str] = Field(...)

# NOTE: remove once https://github.com/hydroshare/hsclient/issues/23 has been resolved
@validator("authors", pre=True, always=True)
@field_validator("authors", mode="before")
def handle_null_author(cls, v):
return v or []

@validator("creator", pre=True, always=True)
@field_validator("creator", mode="before")
def handle_null_creator(cls, v):
return "" if v is None else v


class CollectionOfResourceMetadata(BaseModel):
class CollectionOfResourceMetadata(RootModel):
# from https://github.com/samuelcolvin/pydantic/issues/675#issuecomment-513029543
__root__: List[ResourceMetadata]
root: List[ResourceMetadata]


class ResourceCreationRequest(BaseModel):
Expand All @@ -110,9 +125,14 @@ class ResourceCreationRequest(BaseModel):
resource_type: ResourceTypeEnum


ResFileType = Annotated[str, StringConstraints(pattern=r"^((?!~|\.{2}).)*$")]


class ResourceFiles(BaseModel):
# str in list cannot contain .. or ~
files: List[constr(regex=r"^((?!~|\.{2}).)*$")] = Field(...)
files: List[ResFileType] = Field(...)

model_config = ConfigDict(regex_engine='python-re')


class DataDir(BaseModel):
Expand Down
64 changes: 38 additions & 26 deletions hydroshare_on_jupyter/models/oauth.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
from pydantic import BaseModel
from pydantic import BaseModel, RootModel
from hsclient import Token

# typing imports
from typing import Tuple, Union
from typing import Tuple, Union, Literal, Any, Optional


class OAuthFile(BaseModel):
__root__: Tuple[Token, str]
OAuthFile = RootModel[Tuple[Token, str]]

def dict(
self,
*,
include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None,
by_alias: bool = False,
skip_defaults: bool = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False
) -> "DictStrAny":
d = super().dict(
include=include,
exclude=exclude,
by_alias=by_alias,
skip_defaults=skip_defaults,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
# drop __root__, return only inner model
return d["__root__"]
# TODO: cleanup - also fix the imports above
# class OAuthFile(RootModel, BaseModel):
# root: Tuple[Token, str]
# def model_dump(
# self,
# *,
# mode: Literal['json', 'python'] | str = 'python',
# include: "IncEx" = None,
# exclude: "IncEx" = None,
# context: dict[str, Any] | None = None,
# by_alias: bool = False,
# exclude_unset: bool = False,
# exclude_defaults: bool = False,
# exclude_none: bool = False,
# round_trip: bool = False,
# warnings: bool | Literal['none', 'warn', 'error'] = True,
# serialize_as_any: bool = False,
# ) -> dict[str, Any]:
# d = super().model_dump(
# mode=mode,
# include=include,
# exclude=exclude,
# context=context,
# by_alias=by_alias,
# exclude_unset=exclude_unset,
# exclude_defaults=exclude_defaults,
# exclude_none=exclude_none,
# round_trip=round_trip,
# warnings=warnings,
# serialize_as_any=serialize_as_any,
# )
# # drop __root__, return only inner model
# print(d)
# root, _ = d
# return self.root
Loading

0 comments on commit 206da75

Please sign in to comment.