Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fastapi remote srv #213

Merged
merged 11 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "lavague"
version = "1.0.13"
version = "1.0.14"
description = "Selenium & Playwright code generation from text instructions"
readme = "README.md"
requires-python = ">=3.8"
Expand Down Expand Up @@ -59,7 +59,10 @@ dependencies = [
"gradio==4.21.0",
"ipython",
"langchain==0.1.10",
"datasets"
"datasets",
"fastapi",
"uvicorn",
"starlette"
]

[project.optional-dependencies]
Expand Down
7 changes: 5 additions & 2 deletions src/lavague/action_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(
self.retriever = retriever
self.prompt_template = prompt_template
self.cleaning_function = cleaning_function
self.retrieved_context = ""

def get_query_engine(
self, html: str, streaming: bool = True
Expand Down Expand Up @@ -123,6 +124,7 @@ def get_action(self, query: str, html: str, url: str = "") -> str:
finally:
source_nodes = self.get_nodes(query, html)
retrieved_context = "\n".join(source_nodes)
self.retrieved_context = retrieved_context
send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", success, False, err, retrieved_context)
return code

Expand Down Expand Up @@ -165,7 +167,8 @@ def get_action_streaming(self, query: str, html: str, url: str = "") -> Generato
finally:
source_nodes = self.get_nodes(query, html)
retrieved_context = "\n".join(source_nodes)
send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", success, False, err, retrieved_context)
self.retrieved_context = retrieved_context
send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", None, False, err, retrieved_context)

def get_action_streaming_vscode(self, query: str, html: str, url: str) -> Generator[str, None, None]:
from .telemetry import send_telemetry
Expand All @@ -185,7 +188,7 @@ def get_action_streaming_vscode(self, query: str, html: str, url: str) -> Genera
finally:
source_nodes = self.get_nodes(query, html)
retrieved_context = "\n".join(source_nodes)
send_telemetry(self.llm.metadata.model_name, full_text, "", html, query, url, "lavague-vscode", success, False, err, retrieved_context)
send_telemetry(self.llm.metadata.model_name, full_text, "", html, query, url, "lavague-vscode", None, False, err, retrieved_context)

class TestActionEngine(BaseActionEngine):
"""
Expand Down
94 changes: 94 additions & 0 deletions src/lavague/browser_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# server.py
from typing import Callable, Optional
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import func
import uvicorn
import base64
from lavague.driver import AbstractDriver
from lavague.format_utils import extract_code_from_funct, extract_imports_from_lines

class Request(BaseModel):
code: Optional[str] = None

driver_global: AbstractDriver = None
get_driver: Callable[[], AbstractDriver]

# fastapi set-up
app = FastAPI()

@app.get("/screenshot")
def scr():
global driver_global
if driver_global is None:
driver_global = get_driver()
driver_global.getScreenshot("screenshot.png")
f = open("screenshot.png", "rb")
scr = base64.b64encode(f.read())
return scr.decode("ascii")

@app.post("/exec_code")
def exec_code(req: Request):
global driver_global
if driver_global is None:
driver_global = get_driver()
success = False
error = ""
driver_name, driver = driver_global.getDriver() # define driver for exec
exec(f"{driver_name.strip()} = driver") # define driver in case its name is different
source_code_lines = extract_code_from_funct(get_driver)
import_lines = extract_imports_from_lines(source_code_lines)
code_to_exec = f"""
{import_lines}
{req.code}
"""
try:
exec(code_to_exec)
success = True
except Exception as e:
error = repr(e)
success = False
return {"success": success, "error": error}


@app.get("/get_url")
def geturl():
global driver_global
if driver_global is None:
driver_global = get_driver()
url = driver_global.getUrl()
return url

@app.get("/get_html")
def gethtml():
global driver_global
if driver_global is None:
driver_global = get_driver()
html = driver_global.getHtml()
return {"html": html}

@app.get("/go_to")
def goto(url: str):
global driver_global
if driver_global is None:
driver_global = get_driver()
if url != driver_global.getUrl():
driver_global.goTo(url)
return ""

@app.get("/destroy")
def destroy():
global driver_global
if driver_global is not None:
driver_global.destroy()
return ""

@app.get("/")
def default():
return ""

def run_server(driver_func: Callable[[], AbstractDriver] = None, debug: bool = False):
global get_driver
if driver_func is not None:
get_driver = driver_func
uvicorn.run(app, host="127.0.0.1", port=16500, log_level="debug", workers=1, limit_concurrency=3)
21 changes: 19 additions & 2 deletions src/lavague/cli/commands.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,37 @@
from multiprocessing import Process
from typing import Optional
import click
import warnings
import os

from lavague.evaluator import Evaluator, SeleniumActionEvaluator
from lavague.evaluator import Evaluator
from ..format_utils import extract_code_from_funct, extract_imports_from_lines


@click.group()
def cli():
pass

@cli.command()
@click.pass_context
def driver_server(ctx):
"""Start a server containing the driver"""
from .config import Config
from ..browser_server import run_server
from multiprocessing import Process

config = Config.from_path(ctx.obj["config"])
get_driver = config.get_driver
run_server(get_driver)

@cli.command()
@click.pass_context
def launch(ctx):
"""Start a local gradio demo of lavague"""
from .config import Config, Instructions
from ..command_center import GradioDemo
from ..browser_server import run_server
from multiprocessing import Process

config = Config.from_path(ctx.obj["config"])
if ctx.obj["instructions"] is not None:
Expand All @@ -28,8 +42,11 @@ def launch(ctx):
# We will just pass the get driver func name to the Gradio demo.
# We will call this during driver initialization in init_driver()
get_driver = config.get_driver
command_center = GradioDemo(action_engine, get_driver)
command_center = GradioDemo(action_engine, (get_driver))
p = Process(target=run_server, args=(get_driver, ()))
p.start()
command_center.run(instructions.url, instructions.instructions)
p.join()


@cli.command()
Expand Down
1 change: 1 addition & 0 deletions src/lavague/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _lazy_load(self, cmd_name):
cls=LazyGroup,
lazy_subcommands={
"launch": "lavague.cli.commands.launch",
"driver_server": "lavague.cli.commands.driver_server",
"build": "lavague.cli.commands.build",
"eval": "lavague.cli.commands.evaluation",
"test": "lavague.cli.commands.test",
Expand Down
72 changes: 33 additions & 39 deletions src/lavague/command_center.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,9 @@
from abc import ABC, abstractmethod
import gradio as gr

try:
from selenium.webdriver.common.by import By # import used by generated selenium code
from selenium.webdriver.common.keys import (
Keys,
)
except Exception as e:
pass

from .telemetry import send_telemetry
from .action_engine import ActionEngine
from .driver import AbstractDriver
from .driver import RemoteDriver
import base64

class CommandCenter(ABC):
Expand Down Expand Up @@ -46,34 +38,29 @@ class GradioDemo(CommandCenter):
def __init__(self, actionEngine: ActionEngine, get_driver: callable):
self.actionEngine = actionEngine
self.get_driver = get_driver
self.driver = None
self.driver = RemoteDriver("127.0.0.1", 16500)
self.base_url = ""
self.success = False
self.error = ""

def init_driver(self):
def init_driver_impl(url):
driver = self.get_driver()
driver.goTo(url)
driver.getScreenshot("screenshot.png")
# This function is supposed to fetch and return the image from the URL.
# Placeholder function: replace with actual image fetching logic.
driver.destroy()
self.driver.goTo(url)
self.driver.getScreenshot("screenshot.png")
return "screenshot.png"

return init_driver_impl

def process_instructions(self):
def process_instructions_impl(query, url_input):
driver = self.get_driver()
driver.goTo(url_input)
state = driver.getHtml()
self.driver.goTo(url_input)
state = self.driver.getHtml()
response = ""
print("Generating code...")
for text in self.actionEngine.get_action_streaming(query, state, url_input):
# do something with text as they arrive.
response += text
yield response
driver.destroy()

return process_instructions_impl

Expand All @@ -85,8 +72,11 @@ def telemetry(query, code, html, url_input):
screenshot = base64.b64encode(scr.read())
except:
pass
source_nodes = self.actionEngine.get_nodes(query, html)
retrieved_context = "\n".join(source_nodes)
try:
source_nodes = self.actionEngine.get_nodes(query, html)
retrieved_context = "\n".join(source_nodes)
except:
retrieved_context = self.actionEngine.retrieved_context
send_telemetry(
self.actionEngine.llm.metadata.model_name,
code,
Expand All @@ -105,27 +95,31 @@ def telemetry(query, code, html, url_input):

def __exec_code(self):
def exec_code(url_input, code, full_code):
driver_o = self.get_driver()
self.error = ""
html = self.driver.getHtml()
code = self.actionEngine.cleaning_function(code)
driver_o.goTo(url_input)
html = driver_o.getHtml()
driver_name, driver = driver_o.getDriver() # define driver for exec
exec(f"{driver_name.strip()} = driver") # define driver in case its name is different
try:
exec(code)
output = "Successful code execution"
status = """<p style="color: green; font-size: 20px; font-weight: bold;">Success!</p>"""
self.success = True
full_code += code
url_input = driver_o.getUrl()
driver_o.getScreenshot("screenshot.png")
res = self.driver.execCode(code)
if res["success"] == True:
html = self.driver.getHtml()
url_input = self.driver.getUrl()
self.driver.getScreenshot("screenshot.png")
output = "Successful code execution"
status = """<p style="color: green; font-size: 20px; font-weight: bold;">Success!</p>"""
self.success = True
full_code += code
self.error = ""
else:
err = res["error"]
output = f"Error in code execution: {err}"
status = """<p style="color: red; font-size: 20px; font-weight: bold;">Failure! Open the Debug tab for more information</p>"""
self.success = False
self.error = err
except Exception as e:
output = f"Error in code execution: {str(e)}"
err = repr(e)
output = f"Error in code execution: {err}"
status = """<p style="color: red; font-size: 20px; font-weight: bold;">Failure! Open the Debug tab for more information</p>"""
self.success = False
self.error = repr(e)
driver_o.destroy()
self.error = err
return output, code, html, status, full_code, "screenshot.png", url_input

return exec_code
Expand Down Expand Up @@ -211,4 +205,4 @@ def run(self, base_url: str, instructions: List[str], server_port: int = 7860):
self.__telemetry(),
inputs=[text_area, code_display, full_html, url_input],
)
demo.launch(server_port=server_port, share=True, debug=True)
demo.launch(server_port=server_port, share=True, debug=True, max_threads=1)
3 changes: 2 additions & 1 deletion src/lavague/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,5 @@ def default_get_playwright_driver() -> PlaywrightDriver:
p = sync_playwright().start()
browser = p.chromium.launch()
page = browser.new_page()
return PlaywrightDriver(page, p)
page.set_default_timeout(5000)
return PlaywrightDriver(page)
Loading
Loading