Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forms model getter #970

Merged
merged 4 commits into from
Nov 13, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions core/cat/experimental/form/cat_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,15 @@ def __init__(self, cat) -> None:

self._errors: List[str] = []
self._missing_fields: List[str] = []
self.model_class = self.model_getter()
lucagobbi marked this conversation as resolved.
Show resolved Hide resolved

@property
def cat(self):
return self._cat

def model_getter(self):
return self.model_class

def submit(self, form_data) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -124,7 +128,7 @@ def next(self):
# If the state is INCOMPLETE, execute model update
# (and change state based on validation result)
if self._state == CatFormState.INCOMPLETE:
self._model = self.update()
self.update()

# If state is COMPLETE, ask confirm (or execute action directly)
if self._state == CatFormState.COMPLETE:
Expand All @@ -145,12 +149,11 @@ def update(self):
json_details = self.sanitize(json_details)

# model merge old and new
new_model = self._model | json_details
self._model = self._model | json_details

# Validate new_details
new_model = self.validate(new_model)
self.validate()

return new_model

def message(self):
state_methods = {
Expand Down Expand Up @@ -219,7 +222,7 @@ def extraction_prompt(self):
# JSON structure
# BaseModel.__fields__['my_field'].type_
JSON_structure = "{"
for field_name, field in self.model_class.model_fields.items():
for field_name, field in self.model_getter().model_fields.items():
if field.description:
description = field.description
else:
Expand Down Expand Up @@ -260,15 +263,15 @@ def sanitize(self, model):
return model

# Validate model
def validate(self, model):
def validate(self):
self._missing_fields = []
self._errors = []

try:
# INFO TODO: In this case the optional fields are always ignored

# Attempts to create the model object to update the default values and validate it
model = self.model_class(**model).model_dump(mode="json")
self.model_getter()(**self._model).model_dump(mode="json")

# If model is valid change state to COMPLETE
self._state = CatFormState.COMPLETE
Expand All @@ -281,9 +284,7 @@ def validate(self, model):
self._missing_fields.append(field_name)
else:
self._errors.append(f'{field_name}: {error_message["msg"]}')
del model[field_name]
del self._model[field_name]

# Set state to INCOMPLETE
self._state = CatFormState.INCOMPLETE

return model
self._state = CatFormState.INCOMPLETE
Loading