Skip to content

Commit

Permalink
agent-studio api
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrunel committed Oct 10, 2024
1 parent ac9bdff commit ca5594f
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
NoPageException,
)

from lavague.sdk.action.navigation import WebNavigationAction, NavigationCommand
from lavague.sdk.action import ActionStatus
from selenium.common.exceptions import (
NoSuchElementException,
TimeoutException,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
Element outer HTML: {outer_html}
Test:"""


class QASeleniumExporter(PythonSeleniumExporter):
def __init__(self, model: str = "gpt-4o", time_between_actions: float = 2.5):
self.model = model
Expand Down Expand Up @@ -128,19 +129,23 @@ def export(self, trajectory: TrajectoryData, scenario: str) -> str:
outer_html: str = output.outer_html

prompt = PROMPT_TEMPLATE.format(
test_specs=scenario,
test_specs=scenario,
context=context,
description=description,
text=text,
outer_html=outer_html,
)
test_response = completion(
model = self.model,
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt}
]
).choices[0].message.content
test_response = (
completion(
model=self.model,
messages=[
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
)
.choices[0]
.message.content
)

generated_asserts: str = extract_code_block(test_response)
translated_action_lines.append(generated_asserts)
Expand All @@ -154,11 +159,12 @@ def export(self, trajectory: TrajectoryData, scenario: str) -> str:
translated_actions_str: str = self.merge_code(*translated_actions)
return self.merge_code(setup, translated_actions_str, teardown)


def extract_code_block(code_block):
pattern = r"```(?:python)?\n(.*?)\n```"
match = re.search(pattern, code_block, re.DOTALL)

if match:
return match.group(1).strip()
else:
return "No code block found"
return "No code block found"
2 changes: 2 additions & 0 deletions lavague-sdk/lavague/sdk/action/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from lavague.sdk.action.base import (
Action,
Instruction,
EngineType,
ActionType,
ActionStatus,
ActionParser,
Expand Down
13 changes: 13 additions & 0 deletions lavague-sdk/lavague/sdk/action/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ class ActionType(str, Enum):
EXTRACTION = "web_extraction"


class EngineType(str, Enum):
NAVIGATION = "Navigation Engine"
EXTRACTION = "Element Extraction Engine"
CONTROLS = "Navigation Controls"
COMPLETE = "COMPLETE"


T = TypeVar("T")


Expand Down Expand Up @@ -55,6 +62,12 @@ def parse(self, action_dict: Dict) -> Action:
return Action.parse(action_dict)


class Instruction(BaseModel):
chain_of_toughts: str
engine: EngineType
engine_instruction: str


class UnhandledTypeException(Exception):
pass

Expand Down
1 change: 1 addition & 0 deletions lavague-sdk/lavague/sdk/base_driver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def execute(self, action: NavigationOutput) -> None:
raise NotImplementedError(
f"Action {action.navigation_command} not implemented"
)
self.wait_for_idle()

@abstractmethod
def destroy(self) -> None:
Expand Down
27 changes: 25 additions & 2 deletions lavague-sdk/lavague/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Optional, Tuple

import requests
from lavague.sdk.action import DEFAULT_PARSER, ActionParser
from lavague.sdk.action import DEFAULT_PARSER, ActionParser, Instruction, Action
from lavague.sdk.trajectory import Trajectory
from lavague.sdk.trajectory.controller import TrajectoryController
from lavague.sdk.trajectory.model import StepCompletion
Expand Down Expand Up @@ -86,7 +86,30 @@ def next_step(self, run_id: str) -> StepCompletion:
f"/runs/{run_id}/step",
"POST",
)
return StepCompletion.model_validate_json(content)
return StepCompletion.from_data(content)

def generate_instruction(self, run_id: str) -> Instruction:
content = self.request_api(
f"/runs/{run_id}/step/instruction",
"POST",
)
return Instruction.model_validate_json(content)

def generate_action(self, run_id: str, instruction: Instruction) -> StepCompletion:
content = self.request_api(
f"/runs/{run_id}/step/action",
"POST",
instruction.model_dump(),
)
return StepCompletion.from_data(content)

def execute_action(self, run_id: str, action: StepCompletion) -> StepCompletion:
content = self.request_api(
f"/runs/{run_id}/step/execution",
"POST",
action.model_dump(),
)
return StepCompletion.from_data(content)

def stop(self, run_id: str) -> None:
self.request_api(
Expand Down
40 changes: 32 additions & 8 deletions lavague-sdk/lavague/sdk/trajectory/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from enum import Enum
from typing import Any, Dict, List, Tuple, Optional
from pydantic import BaseModel, SerializeAsAny
from lavague.sdk.action import Action
from lavague.sdk.action import Action, ActionParser
from lavague.sdk.action.base import DEFAULT_PARSER
from pydantic import model_validator
from pydantic_core import from_json


class RunStatus(str, Enum):
STARTING = "starting"
Expand Down Expand Up @@ -31,32 +33,54 @@ class TrajectoryData(BaseModel):
actions: List[SerializeAsAny[Action]]
error_msg: Optional[str] = None

@model_validator(mode='before')
@model_validator(mode="before")
@classmethod
def deserialize_actions(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if 'actions' in values:
actions = values['actions']
if "actions" in values:
actions = values["actions"]
deserialized_actions = []
for action_data in actions:
if isinstance(action_data, Action):
deserialized_actions.append(action_data)
continue
action_type = action_data.get('action_type')
action_type = action_data.get("action_type")
if action_type:
action_class = DEFAULT_PARSER.engine_action_builders.get(action_type, Action)
action_class = DEFAULT_PARSER.engine_action_builders.get(
action_type, Action
)
deserialized_action = action_class.parse(action_data)
deserialized_actions.append(deserialized_action)
else:
deserialized_actions.append(Action.parse(action_data))
values['actions'] = deserialized_actions
values["actions"] = deserialized_actions
return values

def write_to_file(self, file_path: str):
json_model = self.model_dump_json(indent=2)
with open(file_path, "w", encoding="utf-8") as file:
file.write(json_model)


class StepCompletion(BaseModel):
run_status: RunStatus
action: Optional[Action]
run_mode: RunMode
run_mode: RunMode

@classmethod
def from_data(
cls,
data: str | bytes | bytearray,
parser: ActionParser = DEFAULT_PARSER,
):
obj = from_json(data)
return cls.from_dict(obj, parser)

@classmethod
def from_dict(
cls,
data: Dict,
parser: ActionParser = DEFAULT_PARSER,
):
action = data.get("action")
action = parser.parse(action) if action else None
return cls.model_validate({**data, "action": action})

0 comments on commit ca5594f

Please sign in to comment.