Skip to content

Commit

Permalink
Updates model_type -> type
Browse files Browse the repository at this point in the history
This avoids a "protected namespace" pydantic warning
  • Loading branch information
elijahbenizzy committed Feb 16, 2024
1 parent f830c05 commit 514fed8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
12 changes: 6 additions & 6 deletions burr/tracking/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


class IdentifyingModel(pydantic.BaseModel):
model_type: str
type: str


class ActionModel(IdentifyingModel):
Expand All @@ -28,7 +28,7 @@ class ActionModel(IdentifyingModel):
reads: list[str]
writes: list[str]
code: str
model_type: str = "action"
type: str = "action"

@staticmethod
def from_action(action: Action) -> "ActionModel":
Expand All @@ -55,7 +55,7 @@ class TransitionModel(IdentifyingModel):
from_: str
to: str
condition: str
model_type: str = "transition"
type: str = "transition"

@staticmethod
def from_transition(transition: Transition) -> "TransitionModel":
Expand All @@ -70,7 +70,7 @@ class ApplicationModel(IdentifyingModel):
entrypoint: str
actions: list[ActionModel]
transitions: list[TransitionModel]
model_type: str = "application"
type: str = "application"

@staticmethod
def from_application_graph(application_graph: ApplicationGraph) -> "ApplicationModel":
Expand All @@ -90,7 +90,7 @@ class BeginEntryModel(IdentifyingModel):
start_time: datetime.datetime
action: str
inputs: Dict[str, Any]
model_type: str = "begin_entry"
type: str = "begin_entry"


class EndEntryModel(IdentifyingModel):
Expand All @@ -101,4 +101,4 @@ class EndEntryModel(IdentifyingModel):
result: Optional[dict]
exception: Optional[str]
state: Dict[str, Any] # TODO -- consider logging updates to the state so we can recreate
model_type: str = "end_entry"
type: str = "end_entry"
16 changes: 6 additions & 10 deletions tests/tracking/test_local_tracking_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,16 @@ def test_application_tracks_end_to_end(tmpdir: str):
log_contents = [json.loads(item) for item in f.readlines()]
with open(graph_output) as f:
graph_contents = json.load(f)
assert graph_contents["model_type"] == "application"
assert graph_contents["type"] == "application"
app_model = ApplicationModel.parse_obj(graph_contents)
assert app_model.entrypoint == "counter"
assert app_model.actions[0].name == "counter"
assert app_model.actions[1].name == "result"
pre_run = [
BeginEntryModel.parse_obj(line)
for line in log_contents
if line["model_type"] == "begin_entry"
BeginEntryModel.parse_obj(line) for line in log_contents if line["type"] == "begin_entry"
]
post_run = [
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry"
EndEntryModel.parse_obj(line) for line in log_contents if line["type"] == "end_entry"
]
assert len(pre_run) == 3
assert len(post_run) == 3
Expand All @@ -87,18 +85,16 @@ def test_application_tracks_end_to_end_broken(tmpdir: str):
log_contents = [json.loads(item) for item in f.readlines()]
with open(graph_output) as f:
graph_contents = json.load(f)
assert graph_contents["model_type"] == "application"
assert graph_contents["type"] == "application"
app_model = ApplicationModel.parse_obj(graph_contents)
assert app_model.entrypoint == "counter"
assert app_model.actions[0].name == "counter"
assert app_model.actions[1].name == "result"
pre_run = [
BeginEntryModel.parse_obj(line)
for line in log_contents
if line["model_type"] == "begin_entry"
BeginEntryModel.parse_obj(line) for line in log_contents if line["type"] == "begin_entry"
]
post_run = [
EndEntryModel.parse_obj(line) for line in log_contents if line["model_type"] == "end_entry"
EndEntryModel.parse_obj(line) for line in log_contents if line["type"] == "end_entry"
]
assert len(pre_run) == 2
assert len(post_run) == 2
Expand Down

0 comments on commit 514fed8

Please sign in to comment.