Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update API schema + SambaNova Copilot #28

Merged
merged 12 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff
pip install ruff==0.8.0
- name: Run Ruff lint check
run : ruff check --output-format=github .
- name: Run Ruff format check
Expand Down
29 changes: 29 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,32 @@ jobs:
run: |
cd llama31-local-copilot
poetry run pytest tests
sambanova-copilot:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Install Poetry
uses: snok/install-poetry@v1
with:
version: 1.8.3
virtualenvs-create: true
virtualenvs-in-project: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
cd sambanova
python -m pip install --upgrade pip
poetry install
- name: Run Pytest
env:
SAMBANOVA_API_KEY: ${{ secrets.SAMBANOVA_API_KEY }}
run: |
cd sambanova
poetry run pytest tests
130 changes: 82 additions & 48 deletions README.md

Large diffs are not rendered by default.

59 changes: 55 additions & 4 deletions common/common/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, JsonValue, field_validator, model_validator
from enum import Enum
import json

Expand All @@ -11,13 +11,38 @@ class RoleEnum(str, Enum):
tool = "tool"


class LlmFunctionCallResult(BaseModel):
class ChartParameters(BaseModel):
chartType: Literal["line", "bar", "scatter"]
xKey: str
yKey: list[str]


class DataFormat(BaseModel):
"""Describe the format of the data, and how it should be handled."""

type: Literal["text", "table", "chart"] | None = None
chart_params: ChartParameters | None = None


class DataContent(BaseModel):
content: JsonValue = Field(
description="The data content, which must be JSON-serializable. Can be a primitive type (str, int, float, bool), list, or dict." # noqa: E501
)
data_format: DataFormat | None = Field(
default=None,
description="Optional. How the data should be parsed. If not provided, a best-effort attempt will be made to automatically determine the data format.", # noqa: E501
)


class LlmClientFunctionCallResult(BaseModel):
"""Contains the result of a function call made against a client."""

role: RoleEnum = RoleEnum.tool
function: str = Field(description="The name of the called function.")
input_arguments: dict[str, Any] | None = Field(
default=None, description="The input arguments passed to the function"
)
content: str = Field(description="The result of the function call.")
data: list[DataContent] = Field(description="The content of the function call.")


class LlmFunctionCall(BaseModel):
Expand Down Expand Up @@ -80,7 +105,7 @@ class Widget(BaseModel):


class AgentQueryRequest(BaseModel):
messages: list[LlmFunctionCallResult | LlmMessage] = Field(
messages: list[LlmClientFunctionCallResult | LlmMessage] = Field(
description="A list of messages to submit to the copilot."
)
context: str | list[RawContext] | None = Field(
Expand Down Expand Up @@ -131,3 +156,29 @@ class FunctionCallSSEData(BaseModel):
class FunctionCallSSE(BaseSSE):
event: Literal["copilotFunctionCall"] = "copilotFunctionCall"
data: FunctionCallSSEData


class StatusUpdateSSEData(BaseModel):
eventType: Literal["INFO", "WARNING", "ERROR"]
message: str
group: Literal["reasoning"] = "reasoning"
details: list[dict[str, str | int | float | None]] | None = None

@model_validator(mode="before")
@classmethod
def exclude_fields(cls, values):
# Exclude these fields from being in the "details" field. (since this
# pollutes the JSON output)
_exclude_fields = []

if details := values.get("details"):
for detail in details:
for key in list(detail.keys()):
if key.lower() in _exclude_fields:
detail.pop(key, None)
return values


class StatusUpdateSSE(BaseSSE):
event: Literal["copilotStatusUpdate"] = "copilotStatusUpdate"
data: StatusUpdateSSEData
114 changes: 114 additions & 0 deletions common/common/testing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,118 @@
from ast import literal_eval
from pydantic import BaseModel


class CopilotEvent(BaseModel):
event_type: str
content: str | dict


class CopilotResponse:
def __init__(self, event_stream: str):
self.events: list = []
self.index = 0
self.event_stream = event_stream
self.parse_event_stream()

def parse_event_stream(self):
captured_message_chunks = ""
event_name = ""
lines = self.event_stream.split("\n")
for line in lines:
if line.startswith("event:"):
event_type = line.split("event:")[1].strip()
if event_type == "copilotMessageChunk" and line.startswith("data:"):
event_name = "copilotMessageChunk"
data_payload = line.split("data:")[1].strip()
data_dict_ = literal_eval(data_payload)
captured_message_chunks += data_dict_["delta"]
elif event_type == "copilotFunctionCall" and line.startswith("data:"):
event_name = "copilotFunctionCall"
data_payload = line.split("data:")[1].strip()
data_dict_ = literal_eval(data_payload)
self.events.append(
CopilotEvent(event_type=event_name, content=data_dict_)
)
elif event_type == "copilotStatusUpdate" and line.startswith("data:"):
event_name = "copilotStatusUpdate"
data_payload = line.split("data:")[1].strip()
data_dict_ = literal_eval(data_payload)
self.events.append(
CopilotEvent(event_type=event_name, content=data_dict_)
)

self.events.append(
CopilotEvent(event_type="copilotMessage", content=captured_message_chunks)
)

def __iter__(self):
return self

def __next__(self):
if self.index < len(self.events):
event = self.events[self.index]
self.index += 1
return event
else:
raise StopIteration

def starts_with(self, event_type: str, content_contains: str):
self.index = 0
assert self.events[self.index].event_type == event_type
assert content_contains in str(self.events[self.index].content)
return self

def then(self, event_type: str, content_contains: str):
self.index += 1
assert self.events[self.index].event_type == event_type
assert content_contains in str(self.events[self.index].content)
return self

def and_(self, content_contains: str):
assert content_contains in str(self.events[self.index].content)
return self

def and_not(self, content_contains: str):
assert content_contains not in str(self.events[self.index].content)
return self

def then_not(self, event_type: str, content_contains: str):
self.index += 1
assert self.events[self.index].event_type != event_type
assert content_contains not in str(self.events[self.index].content)
return self

def then_ignore(self):
self.index += 1
return self

def ends_with(self, event_type: str, content_contains: str):
self.index = len(self.events) - 1
assert self.events[self.index].event_type == event_type
assert content_contains in str(self.events[self.index].content)
return self

def ends_with_not(self, event_type: str, content_contains: str):
self.index = len(self.events) - 1
assert self.events[self.index].event_type == event_type
assert content_contains not in str(self.events[self.index].content)
return self

def has_any(self, event_type: str, content_contains: str):
assert any(
event_type == event.event_type and content_contains in str(event.content)
for event in self.events
)
return self

def has_all(self, copilot_events: list[CopilotEvent]):
assert all(
copilot_event.event_type == event.event_type
and copilot_event.content == event.content
for copilot_event in copilot_events
for event in self.events
)
return self


def capture_stream_response(event_stream: str) -> tuple[str, str]:
Expand Down
2 changes: 1 addition & 1 deletion common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pydantic = "^2.9.2"


[tool.poetry.group.dev.dependencies]
ruff = "^0.6.9"
ruff = "^0.8.0"

[build-system]
requires = ["poetry-core"]
Expand Down
Loading