Skip to content

Commit

Permalink
Add a distributed setup using GRPC
Browse files Browse the repository at this point in the history
In this setup the different roles of the processing are running in
different processes. These processes are calling into a task manager
in order to communicate with each other via state transitions (encoded in
the `status` of the `Task` messages).

The clients are mock implementations, and only contain examples of the
GRPC calls.

How to try it:

  * Create a venv
  * Install the packages defined in `requirements.txt`
  * Run `build.sh` to generate the message and service stubs
  * Start the task manager first which fires up a GRPC server storing
    tasks
  * Run the client keyboard listener. This contains a mock regarding
    what the keyboard listener should call once the OCR completes
  * Run the llm worker to show an example of how the llm processing
    should happen. Pro tip: To increase parallelism, one can start more
    of these. This takes tasks with `question` and should call the
    LLM to fill out the `answer`.
  * The UI is an example of how to access processed (answered) tasks.
  • Loading branch information
ormandi committed Dec 21, 2023
1 parent 0e60a9d commit 274a65e
Show file tree
Hide file tree
Showing 10 changed files with 586 additions and 91 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
**/__pycache__/**
**/*.log

aihub/aihub_pb2.py
aihub/aihub_pb2.pyi
aihub/aihub_pb2_grpc.py
39 changes: 39 additions & 0 deletions aihub.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
syntax = "proto3";

import "google/protobuf/empty.proto";

package aihub;

// AIHub exposes one RPC per step of a Task's life cycle. The "A -> B" line
// above each RPC names the `status` transition (see TaskStatus) performed by
// that call.
// NOTE(review): "0" in the transitions below presumably denotes a task that
// is not (or no longer) held by the service -- confirm against the server.
service AIHub {
// 0 -> NEW
// Assumes only the field `question` is present in the input task.
// The returned task has `status` set to `NEW` and `id` is present.
rpc AddNewTask (Task) returns (Task) {}

// NEW -> GENERATING_ANSWER
// Takes no input (google.protobuf.Empty); the task to work on is returned by
// the service.
// Assumed `id` and `question` are present, `status` is set to
// `GENERATING_ANSWER` in the returned task.
rpc StartGeneratingAnswer(google.protobuf.Empty) returns (Task) {}

// GENERATING_ANSWER -> ANSWER_AVAILABLE
// Assumes the `id` and the `answer` are present in the input task.
// The returned task has `status` set to `ANSWER_AVAILABLE`.
rpc AddAnswer (Task) returns (Task) {}

// ANSWER_AVAILABLE -> 0
// Takes no input (google.protobuf.Empty); the answered task is returned by
// the service.
// Assumed `status` is set to `ANSWER_AVAILABLE`.
rpc RemoveProcessedQuestion(google.protobuf.Empty) returns (Task) {}
}

// Processing state of a Task. Per the AIHub RPC comments, a task advances
// NEW -> GENERATING_ANSWER -> ANSWER_AVAILABLE.
enum TaskStatus {
// Zero value, so this is what proto3 reports when `status` is left unset.
NEW = 0;
// Set on the task returned by StartGeneratingAnswer.
GENERATING_ANSWER = 1;
// Set on the task returned by AddAnswer.
ANSWER_AVAILABLE = 2;
}

// Task is the single message exchanged by every AIHub RPC: a question, its
// eventual answer, and the processing state in between.
message Task {
// Identifier; present on the task returned by AddNewTask and expected on the
// input of AddAnswer.
int64 id = 1;
// Current processing state (see TaskStatus).
TaskStatus status = 2;
// Question text; the only field expected on the input of AddNewTask.
string question = 3;
// Answer text; supplied by the caller of AddAnswer.
string answer = 4;
}
109 changes: 66 additions & 43 deletions aihub/aihub.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,37 @@
import multiprocessing


version = '1.0 beta'
version = "1.0 beta"
COMBINATIONS = [
{keyboard.Key.shift, keyboard.Key.f1},
{keyboard.Key.shift, keyboard.Key.f12}
{keyboard.Key.shift, keyboard.Key.f12},
]
current = set()


class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKCYAN = "\033[96m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"


class ClickListener:
def __init__(self):
self.coordinates = None
self.keep_listening = True


def on_click(self, x, y, button, pressed):
if pressed:
self.coordinates = int(x), int(y)
self.keep_listening = False
return False # Stop the listener


def get_coordinates(self):
with Listener(on_click=self.on_click) as listener:
listener.join() # This will block until on_click returns False
Expand All @@ -51,30 +50,30 @@ def get_coordinates(self):

def system_printer(text, is_gui=False):
if is_gui:
print(f'\n{text}')
print(f"\n{text}")
else:
print(f'\n{bcolors.HEADER}{text}{bcolors.ENDC}')
print(f"\n{bcolors.HEADER}{text}{bcolors.ENDC}")


def user_printer(text, is_gui=False):
if is_gui:
print(f'\n{text}')
print(f"\n{text}")
else:
print(f'\n{bcolors.WARNING}{text}{bcolors.ENDC}')
print(f"\n{bcolors.WARNING}{text}{bcolors.ENDC}")


def bot_printer(text, is_gui=False):
if is_gui:
print(f'\n{text}')
print(f"\n{text}")
else:
print(f'\n{bcolors.OKGREEN}{text}{bcolors.ENDC}')
print(f"\n{bcolors.OKGREEN}{text}{bcolors.ENDC}")


def spinner():
symbols = ['-', '\\', '|', '/']
symbols = ["-", "\\", "|", "/"]
i = 0
while True:
sys.stdout.write('\r' + symbols[i])
sys.stdout.write("\r" + symbols[i])
sys.stdout.flush()
time.sleep(0.1)
i = (i + 1) % len(symbols)
Expand All @@ -88,15 +87,15 @@ def capture_screen():
x1, y1 = click_listener.get_coordinates()
x2, y2 = click_listener.get_coordinates()

logging.info(f'Rectangle: x1: {x1} y1: {y1} x2: {x2} y2: {y2}')
logging.info(f"Rectangle: x1: {x1} y1: {y1} x2: {x2} y2: {y2}")
try:
image = ImageGrab.grab(bbox=(x1, y1, x2, y2))
except:
return None
return image


def send_request(user_input:str, api:str):
def send_request(user_input: str, api: str):
url = f"{api}/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
Expand All @@ -106,22 +105,22 @@ def send_request(user_input:str, api:str):
"stop": ["### Instruction:"],
"temperature": 0.7,
"max_tokens": -1,
"stream": False
"stream": False,
}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
return response.json()['choices'][0]['message']['content']
return response.json()["choices"][0]["message"]["content"]


def help_me(api:str, prompt_prefix:str, is_gui=False):
def help_me(api: str, prompt_prefix: str, is_gui=False):
global spinner_thread
image = capture_screen()
if image != None:
text = pytesseract.image_to_string(image)
prompt = f'{prompt_prefix}\n {text}'
user_printer('USER:', is_gui)
prompt = f"{prompt_prefix}\n {text}"
user_printer("USER:", is_gui)
user_printer(prompt, is_gui)
bot_printer('BOT:', is_gui)
bot_printer("BOT:", is_gui)
if not is_gui:
spinner_process = multiprocessing.Process(target=spinner)
spinner_process.start()
Expand All @@ -130,24 +129,42 @@ def help_me(api:str, prompt_prefix:str, is_gui=False):
spinner_process.terminate()
spinner_process.join()


return bot_response
pass


def main():
parser = argparse.ArgumentParser(description='aiHub arg parse')
parser.add_argument('-api', '--llm_api_host', type=str, help='LLM API host', default='http://localhost:1234/v1')
parser.add_argument('-pp', '--prompt_prefix', type=str, help='Prefix for every prompt', default='Help me with this')
parser.add_argument('-l', '--log_file', type=str, help='Log file for aiHub', default='aihub.log')
parser.add_argument('-gui', '--gui_printer', action='store_true', help='Using aiHub with aiHubManager GUI')
parser = argparse.ArgumentParser(description="aiHub arg parse")
parser.add_argument(
"-api",
"--llm_api_host",
type=str,
help="LLM API host",
default="http://localhost:1234/v1",
)
parser.add_argument(
"-pp",
"--prompt_prefix",
type=str,
help="Prefix for every prompt",
default="Help me with this",
)
parser.add_argument(
"-l", "--log_file", type=str, help="Log file for aiHub", default="aihub.log"
)
parser.add_argument(
"-gui",
"--gui_printer",
action="store_true",
help="Using aiHub with aiHubManager GUI",
)
args = parser.parse_args()

system_printer(f'I am aiHub | version: {version} | devquasar.com', args.gui_printer)
system_printer(f"I am aiHub | version: {version} | devquasar.com", args.gui_printer)

def log_setup():
log_handler = logging.handlers.WatchedFileHandler(args.log_file)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
formatter.converter = time.gmtime # if you want UTC time
log_handler.setFormatter(formatter)
logger = logging.getLogger()
Expand All @@ -157,17 +174,22 @@ def log_setup():
def on_press(key):
if any([key in COMBO for COMBO in COMBINATIONS]):
current.add(key)
logging.info(f'key combo pressed: {current}')
logging.info(f"key combo pressed: {current}")
# if any(all(k in current for k in COMBO) for COMBO in COMBINATIONS):
# if any(x in current for x in {keyboard.KeyCode(char='A')}):
if any(keypressed in current for keypressed in {keyboard.Key.f1}):
system_printer(f'Shift + F1 pressed.\nMake a screenshot: define the screen area by click 2 corners of a rectangle', args.gui_printer)

bot_response = help_me(args.llm_api_host, args.prompt_prefix, args.gui_printer)
system_printer(
f"Shift + F1 pressed.\nMake a screenshot: define the screen area by click 2 corners of a rectangle",
args.gui_printer,
)

bot_response = help_me(
args.llm_api_host, args.prompt_prefix, args.gui_printer
)
bot_printer(bot_response, args.gui_printer)
# if key == keyboard.Key.shift and any(x in current for x in {keyboard.KeyCode(char='X')}):
if any(keypressed in current for keypressed in {keyboard.Key.f12}):
logging.info('Shift + F12 pressed. Stopping the listener.')
logging.info("Shift + F12 pressed. Stopping the listener.")
listener.stop()

def on_release(key):
Expand All @@ -181,5 +203,6 @@ def on_release(key):
with keyboard.Listener(on_press=on_press, on_release=on_release) as listener:
listener.join()


if __name__ == "__main__":
main()
main()
Loading

0 comments on commit 274a65e

Please sign in to comment.