Skip to content

Commit

Permalink
Adds data models + BE integrations for annotations
Browse files Browse the repository at this point in the history
1. Exposes as mixin in the BE
2. Adds GET/POST/PUT endpoints
3. Creates data models

Does not work with s3 yet.
  • Loading branch information
elijahbenizzy committed Oct 11, 2024
1 parent bf00101 commit b657f87
Show file tree
Hide file tree
Showing 11 changed files with 1,638 additions and 6 deletions.
183 changes: 181 additions & 2 deletions burr/tracking/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import os.path
import sys
from datetime import datetime
from typing import Any, Optional, Sequence, Tuple, Type, TypeVar

import aiofiles
Expand All @@ -15,7 +16,14 @@
from burr.tracking.common import models
from burr.tracking.common.models import ChildApplicationModel
from burr.tracking.server import schema
from burr.tracking.server.schema import ApplicationLogs, ApplicationSummary, Step
from burr.tracking.server.schema import (
AnnotationCreate,
AnnotationOut,
AnnotationUpdate,
ApplicationLogs,
ApplicationSummary,
Step,
)

T = TypeVar("T")

Expand Down Expand Up @@ -59,6 +67,61 @@ async def indexing_jobs(
pass


class AnnotationsBackendMixin(abc.ABC):
@abc.abstractmethod
async def create_annotation(
self,
annotation: AnnotationCreate,
project_id: str,
partition_key: Optional[str],
app_id: str,
step_sequence_id: int,
) -> AnnotationOut:
"""Createse an annotation -- annotation has annotation data, the other pointers are given in the parameters.
:param annotation: Annotation object to create
:param partition_key: Partition key to associate with
:param project_id: Project ID to associate with
:param app_id: App ID to associate with
:param step_sequence_id: Step sequence ID to associate with
:return:
"""

@abc.abstractmethod
async def update_annotation(
self,
annotation: AnnotationUpdate,
project_id: str,
annotation_id: int,
) -> AnnotationOut:
"""Updates an annotation -- annotation has annotation data, the other pointers are given in the parameters.
:param annotation: Annotation object to update
:param project_id: Project ID to associate with
:param annotation_id: Annotation ID to update. We include this as we may have multiple...
:return: Updated annotation
"""

@abc.abstractmethod
async def get_annotations(
self,
project_id: str,
partition_key: Optional[str] = None,
app_id: Optional[str] = None,
step_sequence_id: Optional[int] = None,
) -> Sequence[AnnotationOut]:
"""Returns annotations for a given project, partition_key, app_id, and step sequence ID.
If these are None it does not filter by them.
:param project_id: Project ID to query for
:param partition_key: Partition key to query for
:param app_id: App ID to query for
:param step_sequence_id: Step sequence ID to query for
:return: Annotations
"""
pass


class SnapshottingBackendMixin(abc.ABC):
"""Mixin for backend that conducts snapshotting -- e.g. saves
the data to a file or database."""
Expand Down Expand Up @@ -188,7 +251,7 @@ def get_uri(project_id: str) -> str:
DEFAULT_PATH = os.path.expanduser("~/.burr")


class LocalBackend(BackendBase):
class LocalBackend(BackendBase, AnnotationsBackendMixin):
"""Quick implementation of a local backend for testing purposes. This is not a production backend.
To override the path, set a `burr_path` environment variable to the path you want to use.
Expand All @@ -197,6 +260,122 @@ class LocalBackend(BackendBase):
def __init__(self, path: str = DEFAULT_PATH):
self.path = path

def _get_annotation_path(self, project_id: str) -> str:
return os.path.join(self.path, project_id, "annotations.jsonl")

async def _load_project_annotations(self, project_id: str):
annotations_path = self._get_annotation_path(project_id)
annotations = []
if os.path.exists(annotations_path):
async with aiofiles.open(annotations_path) as f:
for line in await f.readlines():
annotations.append(AnnotationOut.parse_raw(line))
return annotations

async def create_annotation(
self,
annotation: AnnotationCreate,
project_id: str,
partition_key: Optional[str],
app_id: str,
step_sequence_id: int,
) -> AnnotationOut:
"""Creates an annotation by loading all annotations, finding the max ID, and then appending the new annotation.
This is not efficient but it's OK -- this is the local version and the number of annotations will be unlikely to be
huge.
:param annotation: Annotation to create
:param project_id: ID of the associated project
:param partition_key: Partition key to associate with
:param app_id: App ID to associate with
:param step_sequence_id: Step sequence ID to associate with
:return: The created annotation, complete with an ID + timestamps
"""
all_annotations = await self._load_project_annotations(project_id)
annotation_id = (
max([a.id for a in all_annotations], default=-1) + 1
) # get the ID, increment
annotation_out = AnnotationOut(
id=annotation_id,
project_id=project_id,
app_id=app_id,
partition_key=partition_key,
step_sequence_id=step_sequence_id,
created=datetime.now(),
updated=datetime.now(),
**annotation.dict(),
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "a") as f:
await f.write(annotation_out.json() + "\n")
return annotation_out

async def update_annotation(
self,
annotation: AnnotationUpdate,
project_id: str,
annotation_id: int,
) -> AnnotationOut:
"""Updates an annotation by loading all annotations, finding the annotation, updating it, and then writing it back.
Again, inefficient, but this is the local backend and we don't expect huge numbers of annotations.
:param annotation: Annotation to update -- this is just the update fields to the full annotation
:param project_id: ID of the associated project
:param annotation_id: ID of the associated annotation, created by the backend
:return: The updated annotation, complete with an ID + timestamps
"""
all_annotations = await self._load_project_annotations(project_id)
annotation_out = None
for idx, a in enumerate(all_annotations):
if a.id == annotation_id:
annotation_out = a
all_annotations[idx] = annotation_out.copy(
update={**annotation.dict(), "updated": datetime.now()}
)
break
if annotation_out is None:
raise fastapi.HTTPException(
status_code=404,
detail=f"Annotation: {annotation_id} from project: {project_id} not found",
)
annotations_path = self._get_annotation_path(project_id)
async with aiofiles.open(annotations_path, "w") as f:
for a in all_annotations:
await f.write(a.json() + "\n")
return annotation_out

async def get_annotations(
self,
project_id: str,
partition_key: Optional[str] = None,
app_id: Optional[str] = None,
step_sequence_id: Optional[int] = None,
) -> Sequence[AnnotationOut]:
"""Gets the annotation by loading all annotations and filtering by the parameters. Will return all annotations
that match. Only project is required.
:param project_id:
:param partition_key:
:param app_id:
:param step_sequence_id:
:return:
"""
annotation_path = self._get_annotation_path(project_id)
if not os.path.exists(annotation_path):
return []
annotations = []
async with aiofiles.open(annotation_path) as f:
for line in await f.readlines():
parsed = AnnotationOut.parse_raw(line)
if (
(partition_key is None or parsed.partition_key == partition_key)
and (app_id is None or parsed.app_id == app_id)
and (step_sequence_id is None or parsed.step_sequence_id == step_sequence_id)
):
annotations.append(parsed)
return annotations

async def list_projects(self, request: fastapi.Request) -> Sequence[schema.Project]:
out = []
if not os.path.exists(self.path):
Expand Down
76 changes: 73 additions & 3 deletions burr/tracking/server/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import os
from contextlib import asynccontextmanager
from importlib.resources import files
from typing import Sequence
from typing import Optional, Sequence

from starlette import status

# TODO -- remove this, just for testing
from burr.log_setup import setup_logging
from burr.tracking.server.backend import BackendBase, IndexingBackendMixin, SnapshottingBackendMixin
from burr.tracking.server.backend import (
AnnotationsBackendMixin,
BackendBase,
IndexingBackendMixin,
SnapshottingBackendMixin,
)

setup_logging(logging.INFO)

Expand All @@ -23,7 +28,10 @@
from starlette.templating import Jinja2Templates

from burr.tracking.server import schema
from burr.tracking.server.schema import (
from burr.tracking.server.schema import ( # AnnotationUpdate,
AnnotationCreate,
AnnotationOut,
AnnotationUpdate,
ApplicationLogs,
ApplicationPage,
BackendSpec,
Expand Down Expand Up @@ -130,11 +138,13 @@ def is_ready():
def get_app_spec():
is_indexing_backend = isinstance(backend, IndexingBackendMixin)
is_snapshotting_backend = isinstance(backend, SnapshottingBackendMixin)
is_annotations_backend = isinstance(backend, AnnotationsBackendMixin)
supports_demos = backend.supports_demos()
return BackendSpec(
indexing=is_indexing_backend,
snapshotting=is_snapshotting_backend,
supports_demos=supports_demos,
supports_annotations=is_annotations_backend,
)


Expand Down Expand Up @@ -217,6 +227,66 @@ async def get_application_logs(
)


@app.post(
"/api/v0/{project_id}/{app_id}/{partition_key}/{sequence_id}/annotations",
response_model=AnnotationOut,
)
async def create_annotation(
request: Request,
project_id: str,
app_id: str,
partition_key: str,
sequence_id: int,
annotation: AnnotationCreate,
):
if partition_key == SENTINEL_PARTITION_KEY:
partition_key = None
spec = get_app_spec()
if not spec.supports_annotations:
return [] # empty default -- the case that we don't support annotations
return await backend.create_annotation(
annotation, project_id, partition_key, app_id, sequence_id
)


#
# # TODO -- take out these parameters cause we have the annotation ID
@app.put(
"/api/v0/{project_id}/{annotation_id}/update_annotations",
response_model=AnnotationOut,
)
async def update_annotation(
request: Request,
project_id: str,
annotation_id: int,
annotation: AnnotationUpdate,
):
return await backend.update_annotation(
annotation_id=annotation_id, annotation=annotation, project_id=project_id
)


@app.get("/api/v0/{project_id}/annotations", response_model=Sequence[AnnotationOut])
async def get_annotations(
request: Request,
project_id: str,
app_id: Optional[str] = None,
partition_key: Optional[str] = None,
step_sequence_id: Optional[int] = None,
):
# Handle the sentinel value for partition_key
if partition_key == SENTINEL_PARTITION_KEY:
partition_key = None
backend_spec = get_app_spec()

if not backend_spec.supports_annotations:
# makes it easier to wire through to the FE
return []

# Logic to retrieve the annotations
return await backend.get_annotations(project_id, partition_key, app_id, step_sequence_id)


@app.get("/api/v0/ready")
async def ready() -> bool:
return True
Expand Down
51 changes: 50 additions & 1 deletion burr/tracking/server/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import pydantic
from pydantic import fields
Expand Down Expand Up @@ -182,3 +182,52 @@ class BackendSpec(pydantic.BaseModel):
indexing: bool
snapshotting: bool
supports_demos: bool
supports_annotations: bool


class AnnotationDataPointer(pydantic.BaseModel):
type: Literal["state_field", "attribute"]
field_name: str # key of attribute/state field
span_id: Optional[
str
] # span_id if it's associated with a span, otherwise it's associated with an action


AllowedDataField = Literal["note", "ground_truth"]


class AnnotationObservation(pydantic.BaseModel):
data_fields: dict[str, Any]
thumbs_up_thumbs_down: Optional[bool]
data_pointers: List[AnnotationDataPointer]


class AnnotationCreate(pydantic.BaseModel):
"""Generic link for indexing job -- can be exposed in 'admin mode' in the UI"""

span_id: Optional[str]
step_name: str # Should be able to look it up but including for now
tags: List[str]
observations: List[AnnotationObservation]


class AnnotationUpdate(AnnotationCreate):
"""Generic link for indexing job -- can be exposed in 'admin mode' in the UI"""

# Identification for association
span_id: Optional[str] = None
tags: Optional[List[str]] = []
observations: List[AnnotationObservation]


class AnnotationOut(AnnotationCreate):
"""Generic link for indexing job -- can be exposed in 'admin mode' in the UI"""

id: int
# Identification for association
project_id: str # associated project ID
app_id: str
partition_key: Optional[str]
step_sequence_id: int
created: datetime.datetime
updated: datetime.datetime
Loading

0 comments on commit b657f87

Please sign in to comment.