diff --git a/pyproject.toml b/pyproject.toml index 5094b782..ae7024d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lavague" -version = "1.0.13" +version = "1.0.14" description = "Selenium & Playwright code generation from text instructions" readme = "README.md" requires-python = ">=3.8" @@ -59,7 +59,10 @@ dependencies = [ "gradio==4.21.0", "ipython", "langchain==0.1.10", - "datasets" + "datasets", + "fastapi", + "uvicorn", + "starlette" ] [project.optional-dependencies] diff --git a/src/lavague/action_engine.py b/src/lavague/action_engine.py index f5764704..9c565803 100644 --- a/src/lavague/action_engine.py +++ b/src/lavague/action_engine.py @@ -75,6 +75,7 @@ def __init__( self.retriever = retriever self.prompt_template = prompt_template self.cleaning_function = cleaning_function + self.retrieved_context = "" def get_query_engine( self, html: str, streaming: bool = True @@ -123,6 +124,7 @@ def get_action(self, query: str, html: str, url: str = "") -> str: finally: source_nodes = self.get_nodes(query, html) retrieved_context = "\n".join(source_nodes) + self.retrieved_context = retrieved_context send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", success, False, err, retrieved_context) return code @@ -165,7 +167,8 @@ def get_action_streaming(self, query: str, html: str, url: str = "") -> Generato finally: source_nodes = self.get_nodes(query, html) retrieved_context = "\n".join(source_nodes) - send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", success, False, err, retrieved_context) + self.retrieved_context = retrieved_context + send_telemetry(self.llm.metadata.model_name, code, "", html, query, url, "action-engine", None, False, err, retrieved_context) def get_action_streaming_vscode(self, query: str, html: str, url: str) -> Generator[str, None, None]: from .telemetry import send_telemetry @@ -185,7 +188,7 @@ def get_action_streaming_vscode(self, query: str, html: str, url: str) -> Genera finally: source_nodes = self.get_nodes(query, html) retrieved_context = "\n".join(source_nodes) - send_telemetry(self.llm.metadata.model_name, full_text, "", html, query, url, "lavague-vscode", success, False, err, retrieved_context) + send_telemetry(self.llm.metadata.model_name, full_text, "", html, query, url, "lavague-vscode", None, False, err, retrieved_context) class TestActionEngine(BaseActionEngine): """ diff --git a/src/lavague/browser_server.py b/src/lavague/browser_server.py new file mode 100644 index 00000000..eac9f5ec --- /dev/null +++ b/src/lavague/browser_server.py @@ -0,0 +1,94 @@ +# server.py +from typing import Callable, Optional +from fastapi import FastAPI +from pydantic import BaseModel +from sqlalchemy import func +import uvicorn +import base64 +from lavague.driver import AbstractDriver +from lavague.format_utils import extract_code_from_funct, extract_imports_from_lines + +class Request(BaseModel): + code: Optional[str] = None + +driver_global: AbstractDriver = None +get_driver: Callable[[], AbstractDriver] + +# fastapi set-up +app = FastAPI() + +@app.get("/screenshot") +def scr(): + global driver_global + if driver_global is None: + driver_global = get_driver() + driver_global.getScreenshot("screenshot.png") + f = open("screenshot.png", "rb") + scr = base64.b64encode(f.read()) + return scr.decode("ascii") + +@app.post("/exec_code") +def exec_code(req: Request): + global driver_global + if driver_global is None: + driver_global = get_driver() + success = False + error = "" + driver_name, driver = driver_global.getDriver() # define driver for exec + exec(f"{driver_name.strip()} = driver") # define driver in case its name is different + source_code_lines = extract_code_from_funct(get_driver) + import_lines = extract_imports_from_lines(source_code_lines) + code_to_exec = f""" +{import_lines} +{req.code} +""" + try: + exec(code_to_exec) + success = True + except Exception as e: + error = repr(e) + success = False + return {"success": success, "error": error} + + +@app.get("/get_url") +def geturl(): + global driver_global + if driver_global is None: + driver_global = get_driver() + url = driver_global.getUrl() + return url + +@app.get("/get_html") +def gethtml(): + global driver_global + if driver_global is None: + driver_global = get_driver() + html = driver_global.getHtml() + return {"html": html} + +@app.get("/go_to") +def goto(url: str): + global driver_global + if driver_global is None: + driver_global = get_driver() + if url != driver_global.getUrl(): + driver_global.goTo(url) + return "" + +@app.get("/destroy") +def destroy(): + global driver_global + if driver_global is not None: + driver_global.destroy() + return "" + +@app.get("/") +def default(): + return "" + +def run_server(driver_func: Callable[[], AbstractDriver] = None, debug: bool = False): + global get_driver + if driver_func is not None: + get_driver = driver_func + uvicorn.run(app, host="127.0.0.1", port=16500, log_level="debug", workers=1, limit_concurrency=3) \ No newline at end of file diff --git a/src/lavague/cli/commands.py b/src/lavague/cli/commands.py index 22e422b6..7e85ee22 100644 --- a/src/lavague/cli/commands.py +++ b/src/lavague/cli/commands.py @@ -1,9 +1,10 @@ +from multiprocessing import Process from typing import Optional import click import warnings import os -from lavague.evaluator import Evaluator, SeleniumActionEvaluator +from lavague.evaluator import Evaluator from ..format_utils import extract_code_from_funct, extract_imports_from_lines @@ -11,6 +12,17 @@ def cli(): pass +@cli.command() +@click.pass_context +def driver_server(ctx): + """Start a server containing the driver""" + from .config import Config + from ..browser_server import run_server + from multiprocessing import Process + + config = Config.from_path(ctx.obj["config"]) + get_driver = config.get_driver + run_server(get_driver) @cli.command() @click.pass_context @@ -18,6 +30,8 @@ def launch(ctx): """Start a local gradio demo of lavague""" from .config import Config, Instructions from ..command_center import GradioDemo + from ..browser_server import run_server + from multiprocessing import Process config = Config.from_path(ctx.obj["config"]) if ctx.obj["instructions"] is not None: @@ -28,8 +42,11 @@ def launch(ctx): # We will just pass the get driver func name to the Gradio demo. # We will call this during driver initialization in init_driver() get_driver = config.get_driver - command_center = GradioDemo(action_engine, get_driver) + command_center = GradioDemo(action_engine, (get_driver)) + p = Process(target=run_server, args=(get_driver, ())) + p.start() command_center.run(instructions.url, instructions.instructions) + p.join() @cli.command() diff --git a/src/lavague/cli/main.py b/src/lavague/cli/main.py index a495d8d3..83d29e7a 100644 --- a/src/lavague/cli/main.py +++ b/src/lavague/cli/main.py @@ -42,6 +42,7 @@ def _lazy_load(self, cmd_name): cls=LazyGroup, lazy_subcommands={ "launch": "lavague.cli.commands.launch", + "driver_server": "lavague.cli.commands.driver_server", "build": "lavague.cli.commands.build", "eval": "lavague.cli.commands.evaluation", "test": "lavague.cli.commands.test", diff --git a/src/lavague/command_center.py b/src/lavague/command_center.py index a21e07a9..7990f3fd 100644 --- a/src/lavague/command_center.py +++ b/src/lavague/command_center.py @@ -2,17 +2,9 @@ from abc import ABC, abstractmethod import gradio as gr -try: - from selenium.webdriver.common.by import By # import used by generated selenium code - from selenium.webdriver.common.keys import ( - Keys, -) -except Exception as e: - pass - from .telemetry import send_telemetry from .action_engine import ActionEngine -from .driver import AbstractDriver +from .driver import RemoteDriver import base64 class CommandCenter(ABC): @@ -46,34 +38,29 @@ class GradioDemo(CommandCenter): def __init__(self, actionEngine: ActionEngine, get_driver: callable): self.actionEngine = actionEngine self.get_driver = get_driver - self.driver = None + self.driver = RemoteDriver("127.0.0.1", 16500) self.base_url = "" self.success = False self.error = "" def init_driver(self): def init_driver_impl(url): - driver = self.get_driver() - driver.goTo(url) - driver.getScreenshot("screenshot.png") - # This function is supposed to fetch and return the image from the URL. - # Placeholder function: replace with actual image fetching logic. - driver.destroy() + self.driver.goTo(url) + self.driver.getScreenshot("screenshot.png") return "screenshot.png" return init_driver_impl def process_instructions(self): def process_instructions_impl(query, url_input): - driver = self.get_driver() - driver.goTo(url_input) - state = driver.getHtml() + self.driver.goTo(url_input) + state = self.driver.getHtml() response = "" + print("Generating code...") for text in self.actionEngine.get_action_streaming(query, state, url_input): # do something with text as they arrive. response += text yield response - driver.destroy() return process_instructions_impl @@ -85,8 +72,11 @@ def telemetry(query, code, html, url_input): screenshot = base64.b64encode(scr.read()) except: pass - source_nodes = self.actionEngine.get_nodes(query, html) - retrieved_context = "\n".join(source_nodes) + try: + source_nodes = self.actionEngine.get_nodes(query, html) + retrieved_context = "\n".join(source_nodes) + except: + retrieved_context = self.actionEngine.retrieved_context send_telemetry( self.actionEngine.llm.metadata.model_name, code, @@ -105,27 +95,31 @@ def telemetry(query, code, html, url_input): def __exec_code(self): def exec_code(url_input, code, full_code): - driver_o = self.get_driver() - self.error = "" + html = self.driver.getHtml() code = self.actionEngine.cleaning_function(code) - driver_o.goTo(url_input) - html = driver_o.getHtml() - driver_name, driver = driver_o.getDriver() # define driver for exec - exec(f"{driver_name.strip()} = driver") # define driver in case its name is different try: - exec(code) - output = "Successful code execution" - status = """
Success!
""" - self.success = True - full_code += code - url_input = driver_o.getUrl() - driver_o.getScreenshot("screenshot.png") + res = self.driver.execCode(code) + if res["success"] == True: + html = self.driver.getHtml() + url_input = self.driver.getUrl() + self.driver.getScreenshot("screenshot.png") + output = "Successful code execution" + status = """Success!
""" + self.success = True + full_code += code + self.error = "" + else: + err = res["error"] + output = f"Error in code execution: {err}" + status = """Failure! Open the Debug tab for more information
""" + self.success = False + self.error = err except Exception as e: - output = f"Error in code execution: {str(e)}" + err = repr(e) + output = f"Error in code execution: {err}" status = """Failure! Open the Debug tab for more information
""" self.success = False - self.error = repr(e) - driver_o.destroy() + self.error = err return output, code, html, status, full_code, "screenshot.png", url_input return exec_code @@ -211,4 +205,4 @@ def run(self, base_url: str, instructions: List[str], server_port: int = 7860): self.__telemetry(), inputs=[text_area, code_display, full_html, url_input], ) - demo.launch(server_port=server_port, share=True, debug=True) + demo.launch(server_port=server_port, share=True, debug=True, max_threads=1) diff --git a/src/lavague/defaults.py b/src/lavague/defaults.py index da17f9b0..ea96881c 100644 --- a/src/lavague/defaults.py +++ b/src/lavague/defaults.py @@ -156,4 +156,5 @@ def default_get_playwright_driver() -> PlaywrightDriver: p = sync_playwright().start() browser = p.chromium.launch() page = browser.new_page() - return PlaywrightDriver(page, p) + page.set_default_timeout(5000) + return PlaywrightDriver(page) diff --git a/src/lavague/driver.py b/src/lavague/driver.py index 6f92d01f..b65d5638 100644 --- a/src/lavague/driver.py +++ b/src/lavague/driver.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Tuple from .format_utils import clean_html +import requests try: from selenium.webdriver.remote.webdriver import WebDriver @@ -64,6 +65,55 @@ def destroy(self) -> None: pass +class RemoteDriver(AbstractDriver): + def __init__(self, addr: str, port: int = 16500): + self.addr = addr + self.port = port + + def getDriver(self) -> Tuple[str, Any]: + return "driver", None + + def getUrl(self) -> str: + url = requests.get(f"http://{self.addr}:{self.port}/get_url") + url.raise_for_status() + url = url.text.strip('\"') + return url + + def goToUrlCode(self, url: str) -> str: + return f'' + + def goTo(self, url: str) -> None: + res = requests.get(f"http://{self.addr}:{self.port}/go_to", params={"url": url}) + res.raise_for_status() + + def getHtml(self, clean: bool = False) -> str: + html = requests.get(f"http://{self.addr}:{self.port}/get_html").json() + html = html["html"] + return clean_html(html) if clean else html + + def getScreenshot(self, filename: str) -> None: + import base64 + res_txt = requests.get(f"http://{self.addr}:{self.port}/screenshot") + res_txt.raise_for_status() + res_txt = res_txt.text + res = base64.b64decode(res_txt) + f = open(filename, "wb") + f.write(res) + f.close() + + def execCode(self, code: str) -> Any: + res = requests.post(f"http://{self.addr}:{self.port}/exec_code", json={"code": code}) + res.raise_for_status() + res = res.json() + return res + + def getDummyCode(self) -> str: + return '' + + def destroy(self) -> None: + res = requests.get(f"http://{self.addr}:{self.port}/destroy") + res.raise_for_status() + if SELENIUM_IMPORT: class SeleniumDriver(AbstractDriver): @@ -99,9 +149,8 @@ def destroy(self) -> None: if PLAYWRIGHT_IMPORT: class PlaywrightDriver(AbstractDriver): - def __init__(self, sync_playwright_page: Page, context: Playwright): + def __init__(self, sync_playwright_page: Page): self.driver = sync_playwright_page - self.context = context def getDriver(self) -> Tuple[str, Page]: return "page", self.driver @@ -126,4 +175,3 @@ def getDummyCode(self) -> str: def destroy(self) -> None: self.driver.close() - self.context.stop() diff --git a/src/lavague/format_utils.py b/src/lavague/format_utils.py index b41c6bde..c5a545a3 100644 --- a/src/lavague/format_utils.py +++ b/src/lavague/format_utils.py @@ -14,7 +14,6 @@ def extract_code_from_funct(funct: Callable) -> List[str]: line[nident:] for line in source_code_lines[:-1] ] # every line except the return - def extract_imports_from_lines(lines: List[str]) -> str: """Only keep import lines from python code lines and join them""" return "\n".join(