diff --git a/README.md b/README.md index 8bd7652..66ef65c 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Run the server If you don't have public ip and domain name, you can use `ngrok` or similar services to get a https address to the api. The specify the server url in the `wcgw` command like so -`wcgw --server-url https://your-url/register` +`wcgw --server-url https://your-url/v1/register` # [Optional] Local shell access with openai API key diff --git a/gpt_action_json_schema.json b/gpt_action_json_schema.json index 02c72b4..0c5a44b 100644 --- a/gpt_action_json_schema.json +++ b/gpt_action_json_schema.json @@ -50,6 +50,46 @@ } } }, + "/v1/reset_shell": { + "post": { + "x-openai-isConsequential": false, + "summary": "Reset Shell", + "operationId": "reset_shell_v1_reset_shell_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResetShellWithUUID" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "string", + "title": "Response Reset Shell V1 Reset Shell Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/bash_command": { "post": { "x-openai-isConsequential": false, @@ -228,6 +268,29 @@ "type": "object", "title": "HTTPValidationError" }, + "ResetShellWithUUID": { + "properties": { + "should_reset": { + "type": "boolean", + "enum": [ + true + ], + "const": true, + "title": "Should Reset", + "default": true + }, + "user_id": { + "type": "string", + "format": "uuid", + "title": "User Id" + } + }, + "type": "object", + "required": [ + "user_id" + ], + "title": "ResetShellWithUUID" + }, "ValidationError": { "properties": { "loc": { diff --git a/gpt_instructions.txt b/gpt_instructions.txt index 9b3430b..5a4abff 100644 --- a/gpt_instructions.txt +++ b/gpt_instructions.txt @@ -19,6 +19,7 @@ Instructions for `BashCommand`: Instructions for `Write File` - Write content to a file. Provide file path and content. Use this instead of BashCommand for writing files. +- This doesn't create any directories, please create directories using `mkdir -p` BashCommand. - Important: all relative paths are relative to last CWD. Instructions for `BashInteraction` @@ -26,6 +27,9 @@ Instructions for `BashInteraction` - Special keys like arrows, interrupts, enter, etc. - Send text input to the running program. +Instructions for `ResetShell` +- Resets the shell. Use only if all interrupts and prompt reset attempts have failed repeatedly. + --- Always critically think and debate with yourself to solve the problem. Understand the context and the code by reading as much resources as possible before writing a single piece of code. diff --git a/pyproject.toml b/pyproject.toml index ce12384..e63fae0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "websockets>=13.1", "pydantic>=2.9.2", "semantic-version>=2.10.0", + "nltk>=3.9.1", ] [project.urls] diff --git a/src/wcgw/client/basic.py b/src/wcgw/client/basic.py index 12c2caf..987e303 100644 --- a/src/wcgw/client/basic.py +++ b/src/wcgw/client/basic.py @@ -20,7 +20,7 @@ from typer import Typer import uuid -from ..types_ import BashCommand, BashInteraction, ReadImage, Writefile +from ..types_ import BashCommand, BashInteraction, ReadImage, Writefile, ResetShell from .common import Models, discard_input from .common import CostData, History @@ -177,6 +177,10 @@ def loop( openai.pydantic_function_tool( ReadImage, description="Read an image from the shell." ), + openai.pydantic_function_tool( + ResetShell, + description="Resets the shell. Use only if all interrupts and prompt reset attempts have failed repeatedly.", + ), ] uname_sysname = os.uname().sysname uname_machine = os.uname().machine diff --git a/src/wcgw/client/tools.py b/src/wcgw/client/tools.py index e26bda3..cf9a6c0 100644 --- a/src/wcgw/client/tools.py +++ b/src/wcgw/client/tools.py @@ -18,7 +18,6 @@ ) import uuid from pydantic import BaseModel, TypeAdapter -import semantic_version import typer from websockets.sync.client import connect as syncconnect @@ -41,8 +40,8 @@ ChatCompletionMessage, ParsedChatCompletionMessage, ) - -from ..types_ import Writefile +from nltk.metrics.distance import edit_distance +from ..types_ import FileEditFindReplace, ResetShell, Writefile from ..types_ import BashCommand @@ -143,10 +142,21 @@ def _get_exit_code() -> int: CWD = os.getcwd() +def reset_shell() -> str: + global SHELL, BASH_STATE, CWD + SHELL.close(True) + SHELL = start_shell() + BASH_STATE = "repl" + CWD = os.getcwd() + return "Reset successful" + get_status() + + WAITING_INPUT_MESSAGE = """A command is already running. NOTE: You can't run multiple shell sessions, likely a previous program hasn't exited. 1. Get its output using `send_ascii: [10] or send_specials: ["Enter"]` 2. Use `send_ascii` or `send_specials` to give inputs to the running program, don't use `BashCommand` OR -3. kill the previous program by sending ctrl+c first using `send_ascii` or `send_specials`""" +3. kill the previous program by sending ctrl+c first using `send_ascii` or `send_specials` +4. Send the process in background using `send_specials: ["Ctrl-z"]` followed by BashCommand: `bg` +""" def update_repl_prompt(command: str) -> bool: @@ -378,14 +388,9 @@ def wrapper(*args: Param.args, **kwargs: Param.kwargs) -> T: return wrapper -@ensure_no_previous_output def read_image_from_shell(file_path: str) -> ImageData: if not os.path.isabs(file_path): - SHELL.sendline("pwd") - SHELL.expect(PROMPT) - assert isinstance(SHELL.before, str) - current_dir = render_terminal_output(SHELL.before).strip() - file_path = os.path.join(current_dir, file_path) + file_path = os.path.join(CWD, file_path) if not os.path.exists(file_path): raise ValueError(f"File {file_path} does not exist") @@ -411,6 +416,54 @@ def write_file(writefile: Writefile) -> str: return "Success" +def find_least_edit_distance_substring(content: str, find_str: str) -> str: + content_lines = content.split("\n") + find_lines = find_str.split("\n") + # Slide window and find one with sum of edit distance least + min_edit_distance = float("inf") + min_edit_distance_lines = [] + for i in range(len(content_lines) - len(find_lines) + 1): + edit_distance_sum = 0 + for j in range(len(find_lines)): + edit_distance_sum += edit_distance(content_lines[i + j], find_lines[j]) + if edit_distance_sum < min_edit_distance: + min_edit_distance = edit_distance_sum + min_edit_distance_lines = content_lines[i : i + len(find_lines)] + return "\n".join(min_edit_distance_lines) + + +def file_edit(file_edit: FileEditFindReplace) -> str: + if not os.path.isabs(file_edit.file_path): + path_ = os.path.join(CWD, file_edit.file_path) + else: + path_ = file_edit.file_path + + out_string = "\n".join("> " + line for line in file_edit.find_lines.split("\n")) + in_string = "\n".join( + "< " + line for line in file_edit.replace_with_lines.split("\n") + ) + console.log(f"Editing file: {path_}---\n{out_string}\n---{in_string}\n---") + try: + with open(path_) as f: + content = f.read() + # First find counts + count = content.count(file_edit.find_lines) + + if count == 0: + closest_match = find_least_edit_distance_substring( + content, file_edit.find_lines + ) + return f"Error: no match found for the provided `find_lines` in the file. Closest match:\n---\n{closest_match}\n---\nFile not edited" + + content = content.replace(file_edit.find_lines, file_edit.replace_with_lines) + with open(path_, "w") as f: + f.write(content) + except OSError as e: + return f"Error: {e}" + console.print(f"File written to {path_}") + return "Success" + + class DoneFlag(BaseModel): task_output: str @@ -438,7 +491,9 @@ def which_tool(args: str) -> BaseModel: Confirmation | BashCommand | BashInteraction + | ResetShell | Writefile + | FileEditFindReplace | AIAssistant | DoneFlag | ReadImage @@ -446,7 +501,9 @@ def which_tool(args: str) -> BaseModel: Confirmation | BashCommand | BashInteraction + | ResetShell | Writefile + | FileEditFindReplace | AIAssistant | DoneFlag | ReadImage @@ -459,7 +516,9 @@ def get_tool_output( | Confirmation | BashCommand | BashInteraction + | ResetShell | Writefile + | FileEditFindReplace | AIAssistant | DoneFlag | ReadImage, @@ -473,7 +532,9 @@ def get_tool_output( Confirmation | BashCommand | BashInteraction + | ResetShell | Writefile + | FileEditFindReplace | AIAssistant | DoneFlag | ReadImage @@ -481,7 +542,9 @@ def get_tool_output( Confirmation | BashCommand | BashInteraction + | ResetShell | Writefile + | FileEditFindReplace | AIAssistant | DoneFlag | ReadImage @@ -499,6 +562,9 @@ def get_tool_output( elif isinstance(arg, Writefile): console.print("Calling write file tool") output = write_file(arg), 0 + elif isinstance(arg, FileEditFindReplace): + console.print("Calling file edit tool") + output = file_edit(arg), 0.0 elif isinstance(arg, DoneFlag): console.print("Calling mark finish tool") output = mark_finish(arg), 0.0 @@ -508,6 +574,9 @@ def get_tool_output( elif isinstance(arg, ReadImage): console.print("Calling read image tool") output = read_image_from_shell(arg.file_path), 0.0 + elif isinstance(arg, ResetShell): + console.print("Calling reset shell tool") + output = reset_shell(), 0.0 else: raise ValueError(f"Unknown tool: {arg}") @@ -524,7 +593,7 @@ def get_tool_output( class Mdata(BaseModel): - data: BashCommand | BashInteraction | Writefile + data: BashCommand | BashInteraction | Writefile | ResetShell | FileEditFindReplace execution_lock = threading.Lock() diff --git a/src/wcgw/relay/serve.py b/src/wcgw/relay/serve.py index ecc5e55..d90a502 100644 --- a/src/wcgw/relay/serve.py +++ b/src/wcgw/relay/serve.py @@ -14,11 +14,18 @@ from dotenv import load_dotenv -from ..types_ import BashCommand, BashInteraction, Writefile, Specials +from ..types_ import ( + BashCommand, + BashInteraction, + FileEditFindReplace, + ResetShell, + Writefile, + Specials, +) class Mdata(BaseModel): - data: BashCommand | BashInteraction | Writefile + data: BashCommand | BashInteraction | Writefile | ResetShell | FileEditFindReplace user_id: UUID @@ -148,6 +155,75 @@ def put_results(result: str) -> None: raise fastapi.HTTPException(status_code=500, detail="Timeout error") +class FileEditFindReplaceWithUUID(FileEditFindReplace): + user_id: UUID + + +@app.post("/v1/file_edit_find_replace") +async def file_edit_find_replace( + file_edit_find_replace: FileEditFindReplaceWithUUID, +) -> str: + user_id = file_edit_find_replace.user_id + if user_id not in clients: + raise fastapi.HTTPException( + status_code=404, detail="User with the provided id not found" + ) + + results: Optional[str] = None + + def put_results(result: str) -> None: + nonlocal results + results = result + + gpts[user_id] = put_results + + await clients[user_id]( + Mdata( + data=file_edit_find_replace, + user_id=user_id, + ) + ) + + start_time = time.time() + while time.time() - start_time < 30: + if results is not None: + return results + await asyncio.sleep(0.1) + + raise fastapi.HTTPException(status_code=500, detail="Timeout error") + + +class ResetShellWithUUID(ResetShell): + user_id: UUID + + +@app.post("/v1/reset_shell") +async def reset_shell(reset_shell: ResetShellWithUUID) -> str: + user_id = reset_shell.user_id + if user_id not in clients: + raise fastapi.HTTPException( + status_code=404, detail="User with the provided id not found" + ) + + results: Optional[str] = None + + def put_results(result: str) -> None: + nonlocal results + results = result + + gpts[user_id] = put_results + + await clients[user_id](Mdata(data=reset_shell, user_id=user_id)) + + start_time = time.time() + while time.time() - start_time < 30: + if results is not None: + return results + await asyncio.sleep(0.1) + + raise fastapi.HTTPException(status_code=500, detail="Timeout error") + + @app.post("/execute_bash") async def execute_bash_deprecated(excute_bash_data: Any, user_id: UUID) -> Response: return Response( diff --git a/src/wcgw/types_.py b/src/wcgw/types_.py index 39532ad..b05e38b 100644 --- a/src/wcgw/types_.py +++ b/src/wcgw/types_.py @@ -25,3 +25,13 @@ class ReadImage(BaseModel): class Writefile(BaseModel): file_path: str file_content: str + + +class FileEditFindReplace(BaseModel): + file_path: str + find_lines: str + replace_with_lines: str + + +class ResetShell(BaseModel): + should_reset: Literal[True] = True diff --git a/uv.lock b/uv.lock index 76de605..7a6f23f 100644 --- a/uv.lock +++ b/uv.lock @@ -340,6 +340,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/b2/bd6665030f7d7cd5d9182c62a869c3d5ceadd7bff9f1b305de9192e7dbf8/jiter-0.6.1-cp312-none-win_amd64.whl", hash = "sha256:91e63273563401aadc6c52cca64a7921c50b29372441adc104127b910e98a5b6", size = 198966 }, ] +[[package]] +name = "joblib" +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -411,6 +420,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, ] +[[package]] +name = "nltk" +version = "3.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 }, +] + [[package]] name = "openai" version = "1.51.2" @@ -911,6 +935,7 @@ source = { editable = "." } dependencies = [ { name = "fastapi" }, { name = "mypy" }, + { name = "nltk" }, { name = "openai" }, { name = "petname" }, { name = "pexpect" }, @@ -941,6 +966,7 @@ dev = [ requires-dist = [ { name = "fastapi", specifier = ">=0.115.0" }, { name = "mypy", specifier = ">=1.11.2" }, + { name = "nltk", specifier = ">=3.9.1" }, { name = "openai", specifier = ">=1.46.0" }, { name = "petname", specifier = ">=2.6" }, { name = "pexpect", specifier = ">=4.9.0" },